diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile index 067bc80c..3786c0a5 100644 --- a/.github/doxygen/Doxyfile +++ b/.github/doxygen/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v2.5.2 +PROJECT_NUMBER = v3.0.0 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/Editor/LLMCallerEditor.cs b/Editor/LLMClientEditor.cs similarity index 61% rename from Editor/LLMCallerEditor.cs rename to Editor/LLMClientEditor.cs index 8bdd8f58..50b7cac2 100644 --- a/Editor/LLMCallerEditor.cs +++ b/Editor/LLMClientEditor.cs @@ -5,11 +5,11 @@ namespace LLMUnity { - [CustomEditor(typeof(LLMCaller), true)] + [CustomEditor(typeof(LLMClient), true)] public class LLMCallerEditor : PropertyEditor {} - [CustomEditor(typeof(LLMCharacter), true)] - public class LLMCharacterEditor : LLMCallerEditor + [CustomEditor(typeof(LLMAgent), true)] + public class LLMAgentEditor : LLMCallerEditor { public override void AddModelSettings(SerializedObject llmScriptSO) { @@ -23,26 +23,14 @@ public override void AddModelSettings(SerializedObject llmScriptSO) ShowPropertiesOfClass("", llmScriptSO, new List { typeof(ModelAttribute) }, false); EditorGUILayout.BeginHorizontal(); - GUILayout.Label("Grammar", GUILayout.Width(EditorGUIUtility.labelWidth)); if (GUILayout.Button("Load grammar", GUILayout.Width(buttonWidth))) { EditorApplication.delayCall += () => { - string path = EditorUtility.OpenFilePanelWithFilters("Select a gbnf grammar file", "", new string[] { "Grammar Files", "gbnf" }); + string path = EditorUtility.OpenFilePanelWithFilters("Select a gbnf grammar file", "", new string[] { "Grammar Files", "json,gbnf" }); if (!string.IsNullOrEmpty(path)) { - ((LLMCharacter)target).SetGrammar(path); - } - }; - } - if (GUILayout.Button("Load JSON grammar", GUILayout.Width(buttonWidth))) - { - EditorApplication.delayCall += () => - { - string path = EditorUtility.OpenFilePanelWithFilters("Select a json schema grammar file", "", new string[] { "Grammar Files", "json" }); - if (!string.IsNullOrEmpty(path)) - { - ((LLMCharacter)target).SetJSONGrammar(path); + ((LLMAgent)target).LoadGrammar(path); } }; } diff --git a/Runtime/LLMCharacter.cs.meta b/Editor/LLMClientEditor.cs.meta similarity index 83% rename from Runtime/LLMCharacter.cs.meta rename to Editor/LLMClientEditor.cs.meta index bfe1cd03..a8c8d2fc 100644 --- a/Runtime/LLMCharacter.cs.meta +++ b/Editor/LLMClientEditor.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 3f6c87a428fd5d0be9bbc686bdc8c3c2 +guid: 5a933780afd25b58aa243136435108ec MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 5ed29afe..1863a631 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -44,29 +44,17 @@ void AddSSLLoad(string type, Callback setterCallback) } } - void AddSSLInfo(string propertyName, string type, Callback setterCallback) - { - string path = llmScriptSO.FindProperty(propertyName).stringValue; - if (path != "") - { - EditorGUILayout.BeginHorizontal(); - EditorGUILayout.LabelField("SSL " + type + " path", path); - if (GUILayout.Button(trashIcon, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback(""); - EditorGUILayout.EndHorizontal(); - } - } - EditorGUILayout.LabelField("Server Security Settings", EditorStyles.boldLabel); - EditorGUILayout.PropertyField(llmScriptSO.FindProperty("APIKey")); + EditorGUILayout.PropertyField(llmScriptSO.FindProperty("_APIKey")); if (llmScriptSO.FindProperty("advancedOptions").boolValue) { EditorGUILayout.BeginHorizontal(); - AddSSLLoad("certificate", llmScript.SetSSLCert); - AddSSLLoad("key", llmScript.SetSSLKey); + AddSSLLoad("certificate", llmScript.SetSSLCertFromFile); + AddSSLLoad("key", llmScript.SetSSLKeyFromFile); EditorGUILayout.EndHorizontal(); - AddSSLInfo("SSLCertPath", "certificate", llmScript.SetSSLCert); - AddSSLInfo("SSLKeyPath", "key", llmScript.SetSSLKey); + EditorGUILayout.PropertyField(llmScriptSO.FindProperty("_SSLCert")); + EditorGUILayout.PropertyField(llmScriptSO.FindProperty("_SSLKey")); } Space(); } @@ -111,7 +99,7 @@ public override void AddModelSettings(SerializedObject llmScriptSO) if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ModelAdvancedAttribute)); - if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute)); + attributeClasses.Add(typeof(ModelExtrasAttribute)); } ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false); Space(); @@ -126,10 +114,10 @@ static void ResetModelOptions() { List existingOptions = new List(); foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); - modelOptions = new List(){"Download model", "Custom URL"}; - modelNames = new List(){null, null}; - modelURLs = new List(){null, null}; - modelLicenses = new List(){null, null}; + modelOptions = new List() { "Download model", "Custom URL" }; + modelNames = new List() { null, null }; + modelURLs = new List() { null, null }; + modelLicenses = new List() { null, null }; foreach (var entry in LLMUnitySetup.modelOptions) { string category = entry.Key; @@ -146,9 +134,9 @@ static void ResetModelOptions() float[] GetColumnWidths(bool expandedView) { - List widths = new List(){actionColumnWidth, nameColumnWidth, templateColumnWidth}; - if (expandedView) widths.AddRange(new List(){textColumnWidth, textColumnWidth}); - widths.AddRange(new List(){includeInBuildColumnWidth, actionColumnWidth}); + List widths = new List() { actionColumnWidth, nameColumnWidth, templateColumnWidth }; + if (expandedView) widths.AddRange(new List() { textColumnWidth, textColumnWidth }); + widths.AddRange(new List() { includeInBuildColumnWidth, actionColumnWidth }); return widths.ToArray(); } @@ -351,19 +339,6 @@ void OnEnable() DrawCopyableLabel(nameRect, entry.label, entry.filename); - if (!entry.lora) - { - string[] templateDescriptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); - string[] templates = ChatTemplate.templatesDescription.Values.ToList().ToArray(); - int templateIndex = Array.IndexOf(templates, entry.chatTemplate); - int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateDescriptions); - if (newTemplateIndex != templateIndex) - { - LLMManager.SetTemplate(entry.filename, templates[newTemplateIndex]); - UpdateModels(); - } - } - if (expandedView) { if (hasURL) @@ -440,14 +415,14 @@ private void DrawCopyableLabel(Rect rect, string label, string text = "") private void CopyToClipboard(string text) { - TextEditor te = new TextEditor {text = text}; + TextEditor te = new TextEditor { text = text }; te.SelectAll(); te.Copy(); } public void AddExtrasToggle() { - if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); + if (ToggleButton("Use cuBLAS", LLMUnitySetup.CUBLAS)) LLMUnitySetup.SetCUBLAS(!LLMUnitySetup.CUBLAS); } public override void AddOptionsToggles(SerializedObject llmScriptSO) @@ -481,7 +456,7 @@ public override void OnInspectorGUI() AddOptionsToggles(llmScriptSO); AddSetupSettings(llmScriptSO); - if (llmScriptSO.FindProperty("remote").boolValue) AddSecuritySettings(llmScriptSO, llmScript); + if (llmScriptSO.FindProperty("_remote").boolValue) AddSecuritySettings(llmScriptSO, llmScript); AddModelLoadersSettings(llmScriptSO, llmScript); AddChatSettings(llmScriptSO); diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs index 31a727a9..58241446 100644 --- a/Editor/PropertyEditor.cs +++ b/Editor/PropertyEditor.cs @@ -26,7 +26,7 @@ public virtual bool ToggleButton(string text, bool activated) public virtual void AddSetupSettings(SerializedObject llmScriptSO) { List attributeClasses = new List(){typeof(LocalRemoteAttribute)}; - SerializedProperty remoteProperty = llmScriptSO.FindProperty("remote"); + SerializedProperty remoteProperty = llmScriptSO.FindProperty("_remote"); if (remoteProperty != null) attributeClasses.Add(remoteProperty.boolValue ? typeof(RemoteAttribute) : typeof(LocalAttribute)); attributeClasses.Add(typeof(LLMAttribute)); if (llmScriptSO.FindProperty("advancedOptions").boolValue) diff --git a/Options.md b/Options.md index e90b1317..15a2237d 100644 --- a/Options.md +++ b/Options.md @@ -21,15 +21,15 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Debug` select to log the output of the model in the Unity Editor -
Advanced options - -
Parallel Prompts number of prompts / slots that can happen in parallel (default: -1 = number of LLMCharacter objects). Note that the context size is divided among the slots. If you want to retain as much context for the LLM and don't need all the characters present at the same time, you can set this number and specify the slot for each LLMCharacter object. - e.g. Setting `Parallel Prompts` to 1 and slot 0 for all LLMCharacter objects will use the full context, but the entire prompt will need to be computed (no caching) whenever a LLMCharacter object is used for chat.
+ -
Parallel Prompts number of prompts / slots that can happen in parallel (default: -1 = number of LLMAgent objects). Note that the context size is divided among the slots. If you want to retain as much context for the LLM and don't need all the characters present at the same time, you can set this number and specify the slot for each LLMAgent object. + e.g. Setting `Parallel Prompts` to 1 and slot 0 for all LLMAgent objects will use the full context, but the entire prompt will need to be computed (no caching) whenever a LLMAgent object is used for chat.
- `Dont Destroy On Load` select to not destroy the LLM GameObject when loading a new Scene
## Server Security Settings -- `API key` API key to use to allow access to requests from LLMCharacter objects (if `Remote` is set) +- `API key` API key to use to allow access to requests from LLMAgent objects (if `Remote` is set) -
Advanced options - `Load SSL certificate` allows to load a SSL certificate for end-to-end encryption of requests (if `Remote` is set). Requires SSL key as well. @@ -58,7 +58,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
-## LLMCharacter Settings +## LLMAgent Settings - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are @@ -66,7 +66,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU ## 💻 Setup Settings
- +
- `Remote` whether the LLM used is remote or local @@ -113,4 +113,4 @@ If it is not selected, the full reply from the model is received in one go - `N Probs`: if greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) - `Ignore Eos`: enable to ignore end of stream tokens and continue generating (default: false). - \ No newline at end of file + diff --git a/README.md b/README.md index 603e5e29..b108a329 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ First you will setup the LLM for your game 🏎: - Download one of the default models with the `Download Model` button (~GBs).
Or load your own .gguf model with the `Load model` button (see [LLM model management](#llm-model-management)). Then you can setup each of your characters as follows 🙋‍♀️: -- Create an empty GameObject for the character.
In the GameObject Inspector click `Add Component` and select the LLMCharacter script. +- Create an empty GameObject for the character.
In the GameObject Inspector click `Add Component` and select the LLMAgent script. - Define the role of your AI in the `Prompt`. You can define the name of the AI (`AI Name`) and the player (`Player Name`). - (Optional) Select the LLM constructed above in the `LLM` field if you have more than one LLM GameObjects. @@ -104,7 +104,7 @@ In your script you can then use it as follows 🦄: using LLMUnity; public class MyScript { - public LLMCharacter llmCharacter; + public LLMAgent llmAgent; void HandleReply(string reply){ // do something with the reply from the model @@ -115,7 +115,7 @@ public class MyScript { // your game function ... string message = "Hello bot!"; - _ = llmCharacter.Chat(message, HandleReply); + _ = llmAgent.Chat(message, HandleReply); ... } } @@ -132,17 +132,17 @@ This is useful if the `Stream` option is enabled for continuous output from the // your game function ... string message = "Hello bot!"; - _ = llmCharacter.Chat(message, HandleReply, ReplyCompleted); + _ = llmAgent.Chat(message, HandleReply, ReplyCompleted); ... } ``` To stop the chat without waiting for its completion you can use: ``` c# - llmCharacter.CancelRequests(); + llmAgent.CancelRequests(); ``` -- Finally, in the Inspector of the GameObject of your script, select the LLMCharacter GameObject created above as the llmCharacter property. +- Finally, in the Inspector of the GameObject of your script, select the LLMAgent GameObject created above as the llmAgent property. That's all ✨!

@@ -185,17 +185,17 @@ The [MobileDemo](Samples~/MobileDemo) is an example application for Android / iO Restrict the output of the LLM / Function calling To restrict the output of the LLM you can use a grammar, read more [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars).
-The grammar can be saved in a .gbnf file and loaded at the LLMCharacter with the `Load Grammar` button (Advanced options).
+The grammar can be saved in a .gbnf file and loaded at the LLMAgent with the `Load Grammar` button (Advanced options).
For instance to receive replies in json format you can use the [json.gbnf](https://github.com/ggerganov/llama.cpp/blob/b4218/grammars/json.gbnf) grammar.
Graamars in JSON schema format are also supported and can be loaded with the `Load JSON Grammar` button (Advanced options).
Alternatively you can set the grammar directly with code: ``` c# // GBNF grammar -llmCharacter.grammarString = "your GBNF grammar here"; +llmAgent.grammarString = "your GBNF grammar here"; // or JSON schema grammar -llmCharacter.grammarJSONString = "your JSON schema grammar here"; +llmAgent.grammarJSONString = "your JSON schema grammar here"; ``` For function calling you can define similarly a grammar that allows only the function names as output, and then call the respective function.
@@ -204,22 +204,22 @@ You can look into the [FunctionCalling](Samples~/FunctionCalling) sample for an
Access / Save / Load your chat history -The chat history of a `LLMCharacter` is retained in the `chat` variable that is a list of `ChatMessage` objects.
+The chat history of a `LLMAgent` is retained in the `chat` variable that is a list of `ChatMessage` objects.
The ChatMessage is a struct that defines the `role` of the message and the `content`.
The first element of the list is always the system prompt and then alternating messages with the player prompt and the AI reply.
You can modify the chat history directly in this list.
-To automatically save / load your chat history, you can specify the `Save` parameter of the LLMCharacter to the filename (or relative path) of your choice. +To automatically save / load your chat history, you can specify the `Save` parameter of the LLMAgent to the filename (or relative path) of your choice. The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). This also saves the state of the LLM which means that the previously cached prompt does not need to be recomputed. To manually save your chat history, you can use: ``` c# - llmCharacter.Save("filename"); + llmAgent.Save("filename"); ``` and to load the history: ``` c# - llmCharacter.Load("filename"); + llmAgent.Load("filename"); ``` where filename the filename or relative path of your choice. @@ -236,7 +236,7 @@ where filename the filename or relative path of your choice. void Game(){ // your game function ... - _ = llmCharacter.Warmup(WarmupCompleted); + _ = llmAgent.Warmup(WarmupCompleted); ... } ``` @@ -251,7 +251,7 @@ where filename the filename or relative path of your choice. // your game function ... string message = "Hello bot!"; - _ = llmCharacter.Chat(message, HandleReply, ReplyCompleted, false); + _ = llmAgent.Chat(message, HandleReply, ReplyCompleted, false); ... } ``` @@ -265,7 +265,7 @@ where filename the filename or relative path of your choice. // your game function ... string message = "The cat is away"; - _ = llmCharacter.Complete(message, HandleReply, ReplyCompleted); + _ = llmAgent.Complete(message, HandleReply, ReplyCompleted); ... } ``` @@ -280,7 +280,7 @@ where filename the filename or relative path of your choice. // your game function ... string message = "Hello bot!"; - string reply = await llmCharacter.Chat(message, HandleReply, ReplyCompleted); + string reply = await llmAgent.Chat(message, HandleReply, ReplyCompleted); Debug.Log(reply); ... } @@ -288,7 +288,7 @@ where filename the filename or relative path of your choice.
-Add a LLM / LLMCharacter component programmatically +Add a LLM / LLMAgent component programmatically ``` c# using UnityEngine; @@ -297,7 +297,7 @@ using LLMUnity; public class MyScript : MonoBehaviour { LLM llm; - LLMCharacter llmCharacter; + LLMAgent llmAgent; async void Start() { @@ -321,23 +321,23 @@ public class MyScript : MonoBehaviour // optional: enable GPU by setting the number of model layers to offload to it llm.numGPULayers = 10; - // Add an LLMCharacter object - llmCharacter = gameObject.AddComponent(); + // Add an LLMAgent object + llmAgent = gameObject.AddComponent(); // set the LLM object that handles the model - llmCharacter.llm = llm; + llmAgent.llm = llm; // set the character prompt - llmCharacter.SetPrompt("A chat between a curious human and an artificial intelligence assistant."); + llmAgent.SetPrompt("A chat between a curious human and an artificial intelligence assistant."); // set the AI and player name - llmCharacter.AIName = "AI"; - llmCharacter.playerName = "Human"; + llmAgent.assistantRole = "AI"; + llmAgent.userRole = "Human"; // optional: set streaming to false to get the complete result in one go - // llmCharacter.stream = true; + // llmAgent.stream = true; // optional: set a save path - // llmCharacter.save = "AICharacter1"; + // llmAgent.save = "AICharacter1"; // optional: enable the save cache to avoid recomputation when loading a save file (requires ~100 MB) - // llmCharacter.saveCache = true; + // llmAgent.saveCache = true; // optional: set a grammar - // await llmCharacter.SetGrammar("json.gbnf"); + // await llmAgent.SetGrammar("json.gbnf"); // re-enable gameObject gameObject.SetActive(true); @@ -359,14 +359,14 @@ To create the server: Alternatively you can use a server binary for easier deployment: - Run the above scene from the Editor and copy the command from the Debug messages (starting with "Server command:") -- Download the [server binaries](https://github.com/undreamai/LlamaLib/releases/download/v1.2.6/undreamai-v1.2.6-server.zip) and [DLLs](https://github.com/undreamai/LlamaLib/releases/download/v1.2.6/undreamai-v1.2.6-llamacpp-full.zip) and extract them into the same folder +- Download the [server binaries](https://github.com/undreamai/LlamaLib/releases/download/v2.0.0/undreamai-v2.0.0-server.zip) and [DLLs](https://github.com/undreamai/LlamaLib/releases/download/v2.0.0/undreamai-v2.0.0-llamacpp-full.zip) and extract them into the same folder - Find the architecture you are interested in from the folder above e.g. for Windows and CUDA use the `windows-cuda-cu12.2.0`.
You can also check the architecture that works for your system from the Debug messages (starting with "Using architecture"). - From command line change directory to the architecture folder selected and start the server by running the command copied from above. In both cases you'll need to enable 'Allow Downloads Over HTTP' in the project settings. **Create the characters**
-Create a second project with the game characters using the `LLMCharacter` script as described above. +Create a second project with the game characters using the `LLMAgent` script as described above. Enable the `Remote` option and configure the host with the IP address (starting with "http://") and port of the server.
@@ -375,7 +375,7 @@ Enable the `Remote` option and configure the host with the IP address (starting The `Embeddings` function can be used to obtain the emdeddings of a phrase: ``` c# - List embeddings = await llmCharacter.Embeddings("hi, how are you?"); + List embeddings = await llmAgent.Embeddings("hi, how are you?"); ``` @@ -460,7 +460,7 @@ You can use the RAG to feed relevant data to the LLM based on a user message: prompt += $"Data:\n"; foreach (string similarPhrase in similarPhrases) prompt += $"\n- {similarPhrase}"; - _ = llmCharacter.Chat(prompt, HandleReply, ReplyCompleted); + _ = llmAgent.Chat(prompt, HandleReply, ReplyCompleted); ``` The `RAG` sample includes an example RAG implementation as well as an example RAG-LLM integration. diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 3818741b..71482597 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -1,132 +1,342 @@ /// @file -/// @brief File implementing the LLM. +/// @brief File implementing the LLM server component for Unity. using System; using System.Collections.Generic; using System.IO; -using System.Threading; +using System.Linq; using System.Threading.Tasks; +using UndreamAI.LlamaLib; using UnityEditor; using UnityEngine; namespace LLMUnity { - [DefaultExecutionOrder(-1)] /// @ingroup llm /// - /// Class implementing the LLM server. + /// Unity MonoBehaviour component that manages a local LLM server instance. + /// Handles model loading, GPU acceleration, LORA adapters, and provides + /// completion, tokenization, and embedding functionality. /// public class LLM : MonoBehaviour { - /// show/hide advanced options in the GameObject - [Tooltip("show/hide advanced options in the GameObject")] + #region Inspector Fields + /// Show/hide advanced options in the inspector + [Tooltip("Show/hide advanced options in the inspector")] [HideInInspector] public bool advancedOptions = false; - /// enable remote server functionality - [Tooltip("enable remote server functionality")] - [LocalRemote] public bool remote = false; - /// port to use for the remote LLM server - [Tooltip("port to use for the remote LLM server")] - [Remote] public int port = 13333; - /// number of threads to use (-1 = all) - [Tooltip("number of threads to use (-1 = all)")] - [LLM] public int numThreads = -1; - /// number of model layers to offload to the GPU (0 = GPU not used). - /// If the user's GPU is not supported, the LLM will fall back to the CPU - [Tooltip("number of model layers to offload to the GPU (0 = GPU not used). If the user's GPU is not supported, the LLM will fall back to the CPU")] - [LLM] public int numGPULayers = 0; - /// log the output of the LLM in the Unity Editor. - [Tooltip("log the output of the LLM in the Unity Editor.")] - [LLM] public bool debug = false; - /// number of prompts that can happen in parallel (-1 = number of LLMCaller objects) - [Tooltip("number of prompts that can happen in parallel (-1 = number of LLMCaller objects)")] - [LLMAdvanced] public int parallelPrompts = -1; - /// do not destroy the LLM GameObject when loading a new Scene. - [Tooltip("do not destroy the LLM GameObject when loading a new Scene.")] - [LLMAdvanced] public bool dontDestroyOnLoad = true; - /// Size of the prompt context (0 = context size of the model). - /// This is the number of tokens the model can take as input when generating responses. - [Tooltip("Size of the prompt context (0 = context size of the model). This is the number of tokens the model can take as input when generating responses.")] - [DynamicRange("minContextLength", "maxContextLength", false), Model] public int contextSize = 8192; - /// Batch size for prompt processing. - [Tooltip("Batch size for prompt processing.")] - [ModelAdvanced] public int batchSize = 512; - /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. - public bool started { get; protected set; } = false; - /// Boolean set to true if the server has failed to start. - public bool failed { get; protected set; } = false; - /// Boolean set to true if the models were not downloaded successfully. - public static bool modelSetupFailed { get; protected set; } = false; - /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. - public static bool modelSetupComplete { get; protected set; } = false; - /// LLM model to use (.gguf format) - [Tooltip("LLM model to use (.gguf format)")] - [ModelAdvanced] public string model = ""; - /// Chat template for the model - [Tooltip("Chat template for the model")] - [ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate; - /// LORA models to use (.gguf format) - [Tooltip("LORA models to use (.gguf format)")] - [ModelAdvanced] public string lora = ""; - /// the weights of the LORA models being used. - [Tooltip("the weights of the LORA models being used.")] - [ModelAdvanced] public string loraWeights = ""; - /// enable use of flash attention - [Tooltip("enable use of flash attention")] - [ModelExtras] public bool flashAttention = false; - /// API key to use for the server - [Tooltip("API key to use for the server")] - public string APIKey; - - // SSL certificate - [SerializeField] - private string SSLCert = ""; - public string SSLCertPath = ""; - // SSL key - [SerializeField] - private string SSLKey = ""; - public string SSLKeyPath = ""; + /// Enable remote server functionality to allow external connections + [Tooltip("Enable remote server functionality to allow external connections")] + [LocalRemote, SerializeField] private bool _remote = false; + + /// Port to use for the remote LLM server + [Tooltip("Port to use for the remote LLM server")] + [Remote, SerializeField] private int _port = 13333; + + /// API key required for server access (leave empty to disable authentication) + [Tooltip("API key required for server access (leave empty to disable authentication)")] + [SerializeField] private string _APIKey = ""; + + /// SSL certificate for the remote LLM server + [Tooltip("SSL certificate for the remote LLM server")] + [SerializeField] private string _SSLCert = ""; + + /// SSL key for the remote LLM server + [Tooltip("SSL key for the remote LLM server")] + [SerializeField] private string _SSLKey = ""; + + /// Number of threads to use for processing (-1 = use all available threads) + [Tooltip("Number of threads to use for processing (-1 = use all available threads)")] + [LLM, SerializeField] private int _numThreads = -1; + + /// Number of model layers to offload to GPU (0 = CPU only). Falls back to CPU if GPU unsupported + [Tooltip("Number of model layers to offload to GPU (0 = CPU only). Falls back to CPU if GPU unsupported")] + [LLM, SerializeField] private int _numGPULayers = 0; + + /// Number of prompts that can be processed in parallel (-1 = auto-detect from clients) + [Tooltip("Number of prompts that can be processed in parallel (-1 = auto-detect from clients)")] + [LLM, SerializeField] private int _parallelPrompts = -1; + + /// Size of the prompt context in tokens (0 = use model's default context size) + [Tooltip("Size of the prompt context in tokens (0 = use model's default context size). This determines how much conversation history the model can remember.")] + [DynamicRange("minContextLength", "maxContextLength", false), Model, SerializeField] private int _contextSize = 8192; + + /// Batch size for prompt processing (larger = more memory, potentially faster) + [Tooltip("Batch size for prompt processing (larger = more memory, potentially faster)")] + [ModelAdvanced, SerializeField] private int _batchSize = 512; + + /// LLM model file path (.gguf format) + [Tooltip("LLM model file path (.gguf format)")] + [ModelAdvanced, SerializeField] private string _model = ""; + + /// Enable flash attention optimization (requires compatible model) + [Tooltip("Enable flash attention optimization (requires compatible model)")] + [ModelExtras, SerializeField] private bool _flashAttention = false; + + /// Chat template for conversation formatting ("auto" = detect from model) + [Tooltip("Chat template for conversation formatting (\"auto\" = detect from model)")] + [ModelAdvanced, SerializeField] private string _chatTemplate = "auto"; + + /// LORA adapter model paths (.gguf format), separated by commas + [Tooltip("LORA adapter model paths (.gguf format), separated by commas")] + [ModelAdvanced, SerializeField] private string _lora = ""; + + /// Weights for LORA adapters, separated by commas (default: 1.0 for each) + [Tooltip("Weights for LORA adapters, separated by commas (default: 1.0 for each)")] + [ModelAdvanced, SerializeField] private string _loraWeights = ""; + + /// Persist this LLM GameObject across scene transitions + [Tooltip("Persist this LLM GameObject across scene transitions")] + [LLM] public bool dontDestroyOnLoad = true; + #endregion + + #region Public Properties with Validation + + /// Number of threads to use for processing (-1 = use all available threads) + public int numThreads + { + get => _numThreads; + set + { + AssertNotStarted(); + if (value < -1) + throw new ArgumentException("numThreads must be >= -1"); + _numThreads = value; + } + } + + /// Number of model layers to offload to GPU (0 = CPU only) + public int numGPULayers + { + get => _numGPULayers; + set + { + AssertNotStarted(); + if (value < 0) + throw new ArgumentException("numGPULayers must be >= 0"); + _numGPULayers = value; + } + } + + /// Number of prompts that can be processed in parallel (-1 = auto-detect from clients) + public int parallelPrompts + { + get => _parallelPrompts; + set + { + AssertNotStarted(); + if (value < -1) + throw new ArgumentException("parallelPrompts must be >= -1"); + _parallelPrompts = value; + } + } + + /// Size of the prompt context in tokens (0 = use model's default context size) + public int contextSize + { + get => _contextSize; + set + { + AssertNotStarted(); + if (value < 0) + throw new ArgumentException("contextSize must be >= 0"); + _contextSize = value; + } + } + + /// Batch size for prompt processing (larger = more memory, potentially faster) + public int batchSize + { + get => _batchSize; + set + { + AssertNotStarted(); + if (value <= 0) + throw new ArgumentException("batchSize must be > 0"); + _batchSize = value; + } + } + + /// Enable flash attention optimization (requires compatible model) + public bool flashAttention + { + get => _flashAttention; + set + { + AssertNotStarted(); + _flashAttention = value; + } + } + + /// LLM model file path (.gguf format) + public string model + { + get => _model; + set => SetModel(value); + } + + /// Chat template for conversation formatting ("auto" = detect from model) + public string chatTemplate + { + get => _chatTemplate; + set => SetTemplate(value); + } + + /// LORA adapter model paths (.gguf format), separated by commas + public string lora + { + get => _lora; + set + { + if (value == _lora) return; + AssertNotStarted(); + _lora = value; + UpdateLoraManagerFromStrings(); + } + } + + /// Weights for LORA adapters, separated by commas (default: 1.0 for each) + public string loraWeights + { + get => _loraWeights; + set + { + if (value == _loraWeights) return; + _loraWeights = value; + UpdateLoraManagerFromStrings(); + ApplyLoras(); + } + } + + /// Enable remote server functionality to allow external connections + public bool remote + { + get => _remote; + set + { + if (value == _remote) return; + _remote = value; + RestartServer(); + } + } + + /// Port to use for the remote LLM server + public int port + { + get => _port; + set + { + if (value == _port) return; + if (value < 0 || value > 65535) + throw new ArgumentException("port must be between 0 and 65535"); + _port = value; + RestartServer(); + } + } + + /// API key required for server access (leave empty to disable authentication) + public string APIKey + { + get => _APIKey; + set + { + if (value == _APIKey) return; + _APIKey = value; + RestartServer(); + } + } + + /// SSL certificate for the remote LLM server + public string SSLCert + { + get => _SSLCert; + set + { + AssertNotStarted(); + if (value == _SSLCert) return; + _SSLCert = value; + } + } + + /// SSL key for the remote LLM server + public string SSLKey + { + get => _SSLKey; + set + { + AssertNotStarted(); + if (value == _SSLKey) return; + _SSLKey = value; + } + } + + #endregion + + #region Other Public Properties + /// True if the LLM server has started and is ready to receive requests + public bool started { get; private set; } = false; + + /// True if the LLM server failed to start + public bool failed { get; private set; } = false; + + /// True if model setup failed during initialization + public static bool modelSetupFailed { get; private set; } = false; + + /// True if model setup completed (successfully or not) + public static bool modelSetupComplete { get; private set; } = false; + + /// The underlying LLM service instance + public LLMService llmService { get; private set; } + + /// Model architecture name (e.g., "llama", "mistral") + public string architecture => llmlib?.architecture; + + /// True if this model only supports embeddings (no text generation) + public bool embeddingsOnly { get; private set; } = false; + + /// Number of dimensions in embedding vectors (0 if not an embedding model) + public int embeddingLength { get; private set; } = 0; + #endregion + + #region Private Fields /// \cond HIDE public int minContextLength = 0; public int maxContextLength = 0; - public string architecture => llmlib.architecture; - - IntPtr LLMObject = IntPtr.Zero; - List clients = new List(); - LLMLib llmlib; - StreamWrapper logStreamWrapper = null; - Thread llmThread = null; - List streamWrappers = new List(); + + public static readonly string[] ChatTemplates = new string[] + { + "auto", "chatml", "llama2", "llama2-sys", "llama2-sys-bos", "llama2-sys-strip", + "mistral-v1", "mistral-v3", "mistral-v3-tekken", "mistral-v7", "mistral-v7-tekken", + "phi3", "phi4", "falcon3", "zephyr", "monarch", "gemma", "orion", "openchat", + "vicuna", "vicuna-orca", "deepseek", "deepseek2", "deepseek3", "command-r", + "llama3", "chatglm3", "chatglm4", "glmedge", "minicpm", "exaone3", "exaone4", + "rwkv-world", "granite", "gigachat", "megrez", "yandex", "bailing", "llama4", + "smolvlm", "hunyuan-moe", "gpt-oss", "hunyuan-dense", "kimi-k2" + }; + + private LlamaLib llmlib = null; + // [Local, SerializeField] + protected LLMService _llmService; + private readonly List clients = new List(); public LLMManager llmManager = new LLMManager(); - private readonly object startLock = new object(); - static readonly object staticLock = new object(); + private static readonly object staticLock = new object(); public LoraManager loraManager = new LoraManager(); - string loraPre = ""; - string loraWeightsPre = ""; - public bool embeddingsOnly = false; - public int embeddingLength = 0; - /// \endcond + #endregion + #region Unity Lifecycle public LLM() { LLMManager.Register(this); } - void OnValidate() - { - if (lora != loraPre || loraWeights != loraWeightsPre) - { - loraManager.FromStrings(lora, loraWeights); - (loraPre, loraWeightsPre) = (lora, loraWeights); - } - } - /// - /// The Unity Awake function that starts the LLM server. + /// Unity Awake method that initializes the LLM server. + /// Sets up the model, starts the service, and handles GPU fallback if needed. /// public async void Awake() { if (!enabled) return; + #if !UNITY_EDITOR modelSetupFailed = !await LLMManager.Setup(); #endif @@ -136,128 +346,415 @@ public async void Awake() failed = true; return; } - string arguments = GetLlamaccpArguments(); - if (arguments == null) + + await StartServiceAsync(); + if (!started) return; + if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); + } + + public void OnDestroy() + { + Destroy(); + LLMManager.Unregister(this); + } + + #endregion + + #region Initialization + private void ValidateParameters() + { + if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != "")) + { + throw new ArgumentException("Both SSL certificate and key must be provided together!"); + } + } + + private string GetValidatedModelPath() + { + if (string.IsNullOrEmpty(model)) + { + throw new ArgumentException("No model file provided!"); + } + + string modelPath = GetLLMManagerAssetRuntime(model); + if (!File.Exists(modelPath)) + { + throw new ArgumentException($"Model file not found: {modelPath}"); + } + return modelPath; + } + + private List GetValidatedLoraPaths() + { + loraManager.FromStrings(lora, loraWeights); + List loraPaths = new List(); + + foreach (string loraPath in loraManager.GetLoras()) + { + string resolvedPath = GetLLMManagerAssetRuntime(loraPath); + if (!File.Exists(resolvedPath)) + { + throw new ArgumentException($"LORA file not found: {resolvedPath}"); + } + loraPaths.Add(resolvedPath); + } + return loraPaths; + } + + private async Task StartServiceAsync() + { + started = false; + failed = false; + + try + { + ValidateParameters(); + string modelPath = GetValidatedModelPath(); + List loraPaths = GetValidatedLoraPaths(); + + CreateLib(); + await CreateServiceAsync(modelPath, loraPaths); + } + catch (ArgumentException ex) + { + LLMUnitySetup.LogError(ex.Message); + failed = true; + return; + } + catch (Exception ex) { + LLMUnitySetup.LogError($"Failed to create LLM service: {ex.Message}"); + Destroy(); failed = true; return; } - await Task.Run(() => StartLLMServer(arguments)); + + if (started) + { + LLMUnitySetup.Log($"LLM service created successfully, using {architecture}"); + } + } + + private void CreateLib() + { + bool useGPU = numGPULayers > 0; + llmlib = new LlamaLibUnity(useGPU); + + if (LLMUnitySetup.DebugMode <= LLMUnitySetup.DebugModeType.All) + { + LlamaLibUnity.Debug(LLMUnitySetup.DebugModeType.All - LLMUnitySetup.DebugMode + 1); + LlamaLibUnity.LoggingCallback(LLMUnitySetup.Log); + } + } + + /// + /// Setup the remote LLM server + /// + private void SetupServer() + { + if (!remote) return; + + if (!string.IsNullOrEmpty(SSLCert) && !string.IsNullOrEmpty(SSLKey)) + { + LLMUnitySetup.Log("Enabling SSL for server"); + llmService.SetSSL(SSLCert, SSLKey); + } + llmService.StartServer("", port, APIKey); + } + + /// + /// Restart the remote LLM server (on parameter change) + /// + private void RestartServer() + { if (!started) return; - if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); + llmService.StopServer(); + SetupServer(); } + private async Task CreateServiceAsync(string modelPath, List loraPaths) + { + int numSlots = GetNumClients(); + int effectiveThreads = numThreads; + + if (Application.platform == RuntimePlatform.Android && numThreads <= 0) + { + effectiveThreads = LLMUnitySetup.AndroidGetNumBigCores(); + } + + await Task.Run(() => + { + lock (staticLock) + { + IntPtr llmPtr = LLMService.CreateLLM( + llmlib, modelPath, numSlots, effectiveThreads, numGPULayers, + flashAttention, contextSize, batchSize, embeddingsOnly, loraPaths.ToArray()); + + llmService = new LLMService(llmlib, llmPtr); + SetupServer(); + llmService.Start(); + } + }); + + started = llmService.Started(); + if (!started) return; + + ApplyLoras(); + SetTemplate(chatTemplate); + } + + #endregion + + #region Public Methods /// - /// Allows to wait until the LLM is ready + /// Waits asynchronously until the LLM is ready to accept requests. /// + /// Task that completes when LLM is ready public async Task WaitUntilReady() { - while (!started) await Task.Yield(); + while (!started && !failed) + { + await Task.Yield(); + } + + if (failed) + { + throw new InvalidOperationException("LLM failed to start"); + } } /// - /// Allows to wait until the LLM models are downloaded and ready + /// Waits asynchronously until model setup is complete. /// - /// function to call with the download progress (float) + /// Optional callback for download progress updates + /// True if setup succeeded, false if it failed public static async Task WaitUntilModelSetup(Callback downloadProgressCallback = null) { - if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback); - while (!modelSetupComplete) await Task.Yield(); + if (downloadProgressCallback != null) + { + LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback); + } + + while (!modelSetupComplete) + { + await Task.Yield(); + } + return !modelSetupFailed; } - /// \cond HIDE - public static string GetLLMManagerAsset(string path) + /// + /// Sets the model file to use. Automatically configures context size and embedding settings. + /// + /// Path to the model file (.gguf format) + public void SetModel(string path) { + if (model == path) return; + AssertNotStarted(); + + _model = GetLLMManagerAsset(path); + if (string.IsNullOrEmpty(model)) return; + + ModelEntry modelEntry = LLMManager.Get(model) ?? new ModelEntry(GetLLMManagerAssetRuntime(model)); + + maxContextLength = modelEntry.contextLength; + if (contextSize > maxContextLength) + { + contextSize = maxContextLength; + } + + SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly); + + if (contextSize == 0 && modelEntry.contextLength > 32768) + { + LLMUnitySetup.LogWarning($"Model {path} has large context size ({modelEntry.contextLength}). Consider setting contextSize to ≤32768 to avoid excessive memory usage."); + } + #if UNITY_EDITOR - if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path); + if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif - return GetLLMManagerAssetRuntime(path); } - public static string GetLLMManagerAssetEditor(string path) + /// + /// Sets the chat template for message formatting. + /// + /// Template name (see ChatTemplates array for options) + /// Mark object dirty in editor + public void SetTemplate(string templateName, bool setDirty = true) { - // empty - if (string.IsNullOrEmpty(path)) return path; - // LLMManager - return location the file will be stored in StreamingAssets - ModelEntry modelEntry = LLMManager.Get(path); - if (modelEntry != null) return modelEntry.filename; - // StreamingAssets - return relative location within StreamingAssets - string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed - string basePath = LLMUnitySetup.GetAssetPath(); - if (File.Exists(assetPath)) + if (_chatTemplate == templateName) return; + if (!ChatTemplates.Contains(templateName)) { - if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath); + LLMUnitySetup.LogError($"Unsupported chat template: {templateName}"); + return; } - // full path - if (!File.Exists(assetPath)) + + _chatTemplate = templateName; + if (started) { - LLMUnitySetup.LogError($"Model {path} was not found."); + llmService.SetTemplate(_chatTemplate == "auto" ? "" : _chatTemplate); } - else + +#if UNITY_EDITOR + if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#endif + } + + /// + /// Configure the LLM for embedding generation. + /// + /// Number of embedding dimensions + /// True if model only supports embeddings + public void SetEmbeddings(int embeddingLength, bool embeddingsOnly) + { + this.embeddingsOnly = embeddingsOnly; + this.embeddingLength = embeddingLength; + +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#endif + } + + /// + /// Registers an LLMClient for slot management. + /// + /// Client to register + /// Assigned slot ID + public void Register(LLMClient llmClient) + { + if (llmClient == null) { - string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:"; - errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path"; - errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename"; - LLMUnitySetup.LogWarning(errorMessage); + throw new ArgumentNullException(nameof(llmClient)); + } + + clients.Add(llmClient); + } + + /// + /// Tokenizes the provided text into a list of token IDs. + /// + /// Text to tokenize + /// List of token IDs + public List Tokenize(string content) + { + AssertStarted(); + return llmService.Tokenize(content); + } + + /// + /// Converts token IDs back to text. + /// + /// List of token IDs + /// Detokenised text + public string Detokenize(List tokens) + { + AssertStarted(); + return llmService.Detokenize(tokens); + } + + /// + /// Generates embedding vectors for the provided text. + /// + /// Text to embed + /// Embedding vector + public List Embeddings(string content) + { + AssertStarted(); + return llmService.Embeddings(content); + } + + /// + /// Generates text completion for the given prompt. + /// + /// Input prompt + /// Optional callback for streaming responses + /// Slot ID (-1 for automatic assignment) + /// Generated text + public string Completion(string prompt, LlamaLibUnity.CharArrayCallback streamCallback = null, int id_slot = -1) + { + AssertStarted(); + return llmService.Completion(prompt, streamCallback, id_slot); + } + + /// + /// Generates text completion asynchronously. + /// + /// Input prompt + /// Optional callback for streaming responses + /// Slot ID (-1 for automatic assignment) + /// Task that returns generated text + public async Task CompletionAsync(string prompt, LlamaLibUnity.CharArrayCallback streamCallback = null, int id_slot = -1) + { + AssertStarted(); + // Wrap callback to ensure it runs on the main thread + LlamaLib.CharArrayCallback wrappedCallback = Utils.WrapCallbackForAsync(streamCallback); + return await llmService.CompletionAsync(prompt, wrappedCallback, id_slot); + } + + /// + /// Cancels the request in the specified slot. + /// + /// Slot ID + public void CancelRequest(int id_slot) + { + AssertStarted(); + llmService.Cancel(id_slot); + } + + /// + /// Cancels all active requests. + /// + public void CancelRequests() + { + for (int i = 0; i < parallelPrompts; i++) + { + CancelRequest(i); } - return path; } - public static string GetLLMManagerAssetRuntime(string path) + /// + /// Saves the state of a specific slot to disk. + /// + /// Slot ID + /// File path to save to + /// Result message + public string SaveSlot(int idSlot, string filepath) { - // empty - if (string.IsNullOrEmpty(path)) return path; - // LLMManager - string managerPath = LLMManager.GetAssetPath(path); - if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath; - // StreamingAssets - string assetPath = LLMUnitySetup.GetAssetPath(path); - if (File.Exists(assetPath)) return assetPath; - // download path - assetPath = LLMUnitySetup.GetDownloadAssetPath(path); - if (File.Exists(assetPath)) return assetPath; - // give up - return path; + AssertStarted(); + return llmService.SaveSlot(idSlot, filepath); } - /// \endcond - /// - /// Allows to set the model used by the LLM. - /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. - /// Models supported are in .gguf format. + /// Loads the state of a specific slot from disk. /// - /// path to model to use (.gguf format) - public void SetModel(string path) + /// Slot ID + /// File path to load from + /// Result message + public string LoadSlot(int idSlot, string filepath) { - model = GetLLMManagerAsset(path); - if (!string.IsNullOrEmpty(model)) - { - ModelEntry modelEntry = LLMManager.Get(model); - if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model)); - SetTemplate(modelEntry.chatTemplate); + AssertStarted(); + return llmService.LoadSlot(idSlot, filepath); + } - maxContextLength = modelEntry.contextLength; - if (contextSize > maxContextLength) contextSize = maxContextLength; - SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly); - if (contextSize == 0 && modelEntry.contextLength > 32768) - { - LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM"); - } - } -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#endif + /// + /// Gets a list of loaded LORA adapters. + /// + /// List of LORA adapter information + public List ListLoras() + { + AssertStarted(); + return llmService.LoraList(); } + #endregion + + #region LORA Management /// - /// Allows to set a LORA model to use in the LLM. - /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. - /// Models supported are in .gguf format. + /// Sets a single LORA adapter, replacing any existing ones. /// - /// path to LORA model to use (.gguf format) - public void SetLora(string path, float weight = 1) + /// Path to LORA file (.gguf format) + /// Adapter weight (default: 1.0) + public void SetLora(string path, float weight = 1f) { AssertNotStarted(); loraManager.Clear(); @@ -265,12 +762,11 @@ public void SetLora(string path, float weight = 1) } /// - /// Allows to add a LORA model to use in the LLM. - /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. - /// Models supported are in .gguf format. + /// Adds a LORA adapter to the existing set. /// - /// path to LORA model to use (.gguf format) - public void AddLora(string path, float weight = 1) + /// Path to LORA file (.gguf format) + /// Adapter weight (default: 1.0) + public void AddLora(string path, float weight = 1f) { AssertNotStarted(); loraManager.Add(path, weight); @@ -278,10 +774,9 @@ public void AddLora(string path, float weight = 1) } /// - /// Allows to remove a LORA model from the LLM. - /// Models supported are in .gguf format. + /// Removes a specific LORA adapter. /// - /// path to LORA model to remove (.gguf format) + /// Path to LORA file to remove public void RemoveLora(string path) { AssertNotStarted(); @@ -290,7 +785,7 @@ public void RemoveLora(string path) } /// - /// Allows to remove all LORA models from the LLM. + /// Removes all LORA adapters. /// public void RemoveLoras() { @@ -300,10 +795,10 @@ public void RemoveLoras() } /// - /// Allows to change the weight (scale) of a LORA model in the LLM. + /// Changes the weight of a specific LORA adapter. /// - /// path of LORA model to change (.gguf format) - /// weight of LORA + /// Path to LORA file + /// New weight value public void SetLoraWeight(string path, float weight) { loraManager.SetWeight(path, weight); @@ -312,538 +807,228 @@ public void SetLoraWeight(string path, float weight) } /// - /// Allows to change the weights (scale) of the LORA models in the LLM. + /// Changes the weights of multiple LORA adapters. /// - /// Dictionary (string, float) mapping the path of LORA models with weights to change + /// Dictionary mapping LORA paths to weights public void SetLoraWeights(Dictionary loraToWeight) { - foreach (KeyValuePair entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value); + if (loraToWeight == null) + { + throw new ArgumentNullException(nameof(loraToWeight)); + } + + foreach (var entry in loraToWeight) + { + loraManager.SetWeight(entry.Key, entry.Value); + } UpdateLoras(); if (started) ApplyLoras(); } - public void UpdateLoras() + private void UpdateLoras() { - (lora, loraWeights) = loraManager.ToStrings(); - (loraPre, loraWeightsPre) = (lora, loraWeights); + (_lora, _loraWeights) = loraManager.ToStrings(); #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif } - /// - /// Set the chat template for the LLM. - /// - /// the chat template to use. The available templates can be found in the ChatTemplate.templates.Keys array - public void SetTemplate(string templateName, bool setDirty = true) + private void UpdateLoraManagerFromStrings() { - chatTemplate = templateName; - if (started) llmlib?.LLM_SetTemplate(LLMObject, chatTemplate); -#if UNITY_EDITOR - if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#endif + loraManager.FromStrings(_lora, _loraWeights); } - /// - /// Set LLM Embedding parameters - /// - /// number of embedding dimensions - /// if true, the LLM will be used only for embeddings - public void SetEmbeddings(int embeddingLength, bool embeddingsOnly) + private void ApplyLoras() { - this.embeddingsOnly = embeddingsOnly; - this.embeddingLength = embeddingLength; -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#endif - } + if (!started) return; + var loras = new List(); + float[] weights = loraManager.GetWeights(); - /// \cond HIDE + for (int i = 0; i < weights.Length; i++) + { + loras.Add(new LoraIdScale(i, weights[i])); + } - string ReadFileContents(string path) - { - if (String.IsNullOrEmpty(path)) return ""; - else if (!File.Exists(path)) + if (loras.Count > 0) { - LLMUnitySetup.LogError($"File {path} not found!"); - return ""; + llmService.LoraWeight(loras); } - return File.ReadAllText(path); } - /// \endcond + #endregion + #region SSL Configuration /// - /// Use a SSL certificate for the LLM server. + /// Sets the SSL certificate for secure server connections. /// - /// the SSL certificate path - public void SetSSLCert(string path) + /// Path to SSL certificate file + public void SetSSLCertFromFile(string path) { - SSLCertPath = path; SSLCert = ReadFileContents(path); } /// - /// Use a SSL key for the LLM server. + /// Sets the SSL private key for secure server connections. /// - /// the SSL key path - public void SetSSLKey(string path) + /// Path to SSL private key file + public void SetSSLKeyFromFile(string path) { - SSLKeyPath = path; SSLKey = ReadFileContents(path); } - /// - /// Returns the chat template of the LLM. - /// - /// chat template of the LLM - public string GetTemplate() - { - return chatTemplate; - } - - protected virtual string GetLlamaccpArguments() + private string ReadFileContents(string path) { - // Start the LLM server in a cross-platform way - if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != "")) - { - LLMUnitySetup.LogError($"Both SSL certificate and key need to be provided!"); - return null; - } - - if (model == "") - { - LLMUnitySetup.LogError("No model file provided!"); - return null; - } - string modelPath = GetLLMManagerAssetRuntime(model); - if (!File.Exists(modelPath)) - { - LLMUnitySetup.LogError($"File {modelPath} not found!"); - return null; - } - - loraManager.FromStrings(lora, loraWeights); - string loraArgument = ""; - foreach (string lora in loraManager.GetLoras()) - { - string loraPath = GetLLMManagerAssetRuntime(lora); - if (!File.Exists(loraPath)) - { - LLMUnitySetup.LogError($"File {loraPath} not found!"); - return null; - } - loraArgument += $" --lora \"{loraPath}\""; - } - - int numThreadsToUse = numThreads; - if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores(); + if (string.IsNullOrEmpty(path)) return ""; - int slots = GetNumClients(); - string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}"; - if (embeddingsOnly) arguments += " --embedding"; - if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; - arguments += loraArgument; - if (numGPULayers > 0) arguments += $" -ngl {numGPULayers}"; - if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; - if (remote) + if (!File.Exists(path)) { - arguments += $" --port {port} --host 0.0.0.0"; - if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}"; + LLMUnitySetup.LogError($"File not found: {path}"); + return ""; } - // the following is the equivalent for running from command line - string serverCommand; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe"; - else serverCommand = "./undreamai_server"; - serverCommand += " " + arguments; - serverCommand += $" --template \"{chatTemplate}\""; - if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}"; - LLMUnitySetup.Log($"Deploy server command: {serverCommand}"); - return arguments; + return File.ReadAllText(path); } - private void SetupLogging() - { - logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true); - llmlib?.Logging(logStreamWrapper.GetStringWrapper()); - } + #endregion - private void StopLogging() + #region Helper Methods + private int GetNumClients() { - if (logStreamWrapper == null) return; - llmlib?.StopLogging(); - DestroyStreamWrapper(logStreamWrapper); + return Math.Max(parallelPrompts == -1 ? clients.Count : parallelPrompts, 1); } - private void StartLLMServer(string arguments) + private void AssertStarted() { - started = false; - failed = false; - bool useGPU = numGPULayers > 0; - - foreach (string arch in LLMLib.PossibleArchitectures(useGPU)) - { - string error; - try - { - InitLib(arch); - InitService(arguments); - LLMUnitySetup.Log($"Using architecture: {arch}"); - break; - } - catch (LLMException e) - { - error = e.Message; - Destroy(); - } - catch (DestroyException) - { - break; - } - catch (Exception e) - { - error = $"{e.GetType()}: {e.Message}"; - } - LLMUnitySetup.Log($"Tried architecture: {arch}, error: " + error); - } - if (llmlib == null) + string error = null; + if (failed) error = "LLM service couldn't be created"; + else if (!started) error = "LLM service not started"; + if (error != null) { - LLMUnitySetup.LogError("LLM service couldn't be created"); - failed = true; - return; + LLMUnitySetup.LogError(error); + throw new Exception(error); } - CallWithLock(StartService); - LLMUnitySetup.Log("LLM service created"); - } - - private void InitLib(string arch) - { - llmlib = new LLMLib(arch); - CheckLLMStatus(false); } - void CallWithLock(EmptyCallback fn) + private void AssertNotStarted() { - lock (startLock) + if (started) { - if (llmlib == null) throw new DestroyException(); - fn(); + string error = "This method can't be called when the LLM has started"; + LLMUnitySetup.LogError(error); + throw new Exception(error); } } - private void InitService(string arguments) + /// + /// Stops and cleans up the LLM service. + /// + public void Destroy() { lock (staticLock) { - if (debug) CallWithLock(SetupLogging); - CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); }); - CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); - if (remote) + try + { + llmService?.Dispose(); + llmlib = null; + started = false; + failed = false; + } + catch (Exception ex) { - if (SSLCert != "" && SSLKey != "") - { - LLMUnitySetup.Log("Using SSL"); - CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey)); - } - CallWithLock(() => llmlib.LLM_StartServer(LLMObject)); + LLMUnitySetup.LogError($"Error during LLM cleanup: {ex.Message}"); } - CallWithLock(() => CheckLLMStatus(false)); } } - private void StartService() - { - llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); - llmThread.Start(); - while (!llmlib.LLM_Started(LLMObject)) {} - ApplyLoras(); - started = true; - } - - /// - /// Registers a local LLMCaller object. - /// This allows to bind the LLMCaller "client" to a specific slot of the LLM. - /// - /// - /// - public int Register(LLMCaller llmCaller) - { - clients.Add(llmCaller); - int index = clients.IndexOf(llmCaller); - if (parallelPrompts != -1) return index % parallelPrompts; - return index; - } - - protected int GetNumClients() - { - return Math.Max(parallelPrompts == -1 ? clients.Count : parallelPrompts, 1); - } + #endregion + #region Static Asset Management /// \cond HIDE - public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper); - public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper); - public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper); - /// \endcond - - StreamWrapper ConstructStreamWrapper(Callback streamCallback = null, bool clearOnUpdate = false) + public static string GetLLMManagerAsset(string path) { - StreamWrapper streamWrapper = new StreamWrapper(llmlib, streamCallback, clearOnUpdate); - streamWrappers.Add(streamWrapper); - return streamWrapper; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path); +#endif + return GetLLMManagerAssetRuntime(path); } - void DestroyStreamWrapper(StreamWrapper streamWrapper) + public static string GetLLMManagerAssetEditor(string path) { - streamWrappers.Remove(streamWrapper); - streamWrapper.Destroy(); - } + if (string.IsNullOrEmpty(path)) return path; - /// - /// The Unity Update function. It is used to retrieve the LLM replies. - public void Update() - { - foreach (StreamWrapper streamWrapper in streamWrappers) streamWrapper.Update(); - } + // Check LLMManager first + ModelEntry modelEntry = LLMManager.Get(path); + if (modelEntry != null) return modelEntry.filename; - void AssertStarted() - { - string error = null; - if (failed) error = "LLM service couldn't be created"; - else if (!started) error = "LLM service not started"; - if (error != null) - { - LLMUnitySetup.LogError(error); - throw new Exception(error); - } - } + // Check StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); + string basePath = LLMUnitySetup.GetAssetPath(); - void AssertNotStarted() - { - if (started) + if (File.Exists(assetPath) && LLMUnitySetup.IsSubPath(assetPath, basePath)) { - string error = "This method can't be called when the LLM has started"; - LLMUnitySetup.LogError(error); - throw new Exception(error); + return LLMUnitySetup.RelativePath(assetPath, basePath); } - } - void CheckLLMStatus(bool log = true) - { - if (llmlib == null) { return; } - IntPtr stringWrapper = llmlib.StringWrapper_Construct(); - int status = llmlib.LLM_Status(LLMObject, stringWrapper); - string result = llmlib.GetStringWrapperResult(stringWrapper); - llmlib.StringWrapper_Delete(stringWrapper); - string message = $"LLM {status}: {result}"; - if (status > 0) + // Warn about local files not in build + if (File.Exists(assetPath)) { - if (log) LLMUnitySetup.LogError(message); - throw new LLMException(message, status); + string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:"; + errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path"; + errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename"; + LLMUnitySetup.LogWarning(errorMessage); } - else if (status < 0) + else { - if (log) LLMUnitySetup.LogWarning(message); + LLMUnitySetup.LogError($"Model file not found: {path}"); } - } - async Task LLMNoInputReply(LLMNoInputReplyCallback callback) - { - AssertStarted(); - IntPtr stringWrapper = llmlib.StringWrapper_Construct(); - await Task.Run(() => callback(LLMObject, stringWrapper)); - string result = llmlib?.GetStringWrapperResult(stringWrapper); - llmlib?.StringWrapper_Delete(stringWrapper); - CheckLLMStatus(); - return result; - } - - async Task LLMReply(LLMReplyCallback callback, string json) - { - AssertStarted(); - IntPtr stringWrapper = llmlib.StringWrapper_Construct(); - await Task.Run(() => callback(LLMObject, json, stringWrapper)); - string result = llmlib?.GetStringWrapperResult(stringWrapper); - llmlib?.StringWrapper_Delete(stringWrapper); - CheckLLMStatus(); - return result; - } - - /// - /// Tokenises the provided query. - /// - /// json request containing the query - /// tokenisation result - public async Task Tokenize(string json) - { - AssertStarted(); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Tokenize(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); - } - - /// - /// Detokenises the provided query. - /// - /// json request containing the query - /// detokenisation result - public async Task Detokenize(string json) - { - AssertStarted(); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Detokenize(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); + return path; } - /// - /// Computes the embeddings of the provided query. - /// - /// json request containing the query - /// embeddings result - public async Task Embeddings(string json) + public static string GetLLMManagerAssetRuntime(string path) { - AssertStarted(); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); - } + if (string.IsNullOrEmpty(path)) return path; - /// - /// Sets the lora scale, only works after the LLM service has started - /// - /// switch result - public void ApplyLoras() - { - LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList(); - loraWeightRequest.loraWeights = new List(); - float[] weights = loraManager.GetWeights(); - if (weights.Length == 0) return; - for (int i = 0; i < weights.Length; i++) + // Try LLMManager path + string managerPath = LLMManager.GetAssetPath(path); + if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) { - loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] }); + return managerPath; } - string json = JsonUtility.ToJson(loraWeightRequest); - int startIndex = json.IndexOf("["); - int endIndex = json.LastIndexOf("]") + 1; - json = json.Substring(startIndex, endIndex - startIndex); + // Try StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); + if (File.Exists(assetPath)) return assetPath; - IntPtr stringWrapper = llmlib.StringWrapper_Construct(); - llmlib.LLM_LoraWeight(LLMObject, json, stringWrapper); - llmlib.StringWrapper_Delete(stringWrapper); - } + // Try download path + string downloadPath = LLMUnitySetup.GetDownloadAssetPath(path); + if (File.Exists(downloadPath)) return downloadPath; - /// - /// Gets a list of the lora adapters - /// - /// list of lara adapters - public async Task> ListLoras() - { - AssertStarted(); - LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) => - { - llmlib.LLM_LoraList(LLMObject, strWrapper); - }; - string json = await LLMNoInputReply(callback); - if (String.IsNullOrEmpty(json)) return null; - LoraWeightResultList loraRequest = JsonUtility.FromJson("{\"loraWeights\": " + json + "}"); - return loraRequest.loraWeights; + return path; } - /// - /// Allows to save / restore the state of a slot - /// - /// json request containing the query - /// slot result - public async Task Slot(string json) - { - AssertStarted(); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Slot(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); - } + /// \endcond + #endregion + } - /// - /// Allows to use the chat and completion functionality of the LLM. - /// - /// json request containing the query - /// callback function to call with intermediate responses - /// completion result - public async Task Completion(string json, Callback streamCallback = null) - { - AssertStarted(); - if (streamCallback == null) streamCallback = (string s) => {}; - StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback); - await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper())); - if (!started) return null; - streamWrapper.Update(); - string result = streamWrapper.GetString(); - DestroyStreamWrapper(streamWrapper); - CheckLLMStatus(); - return result; - } + /// + /// Unity-specific implementation of LlamaLib for handling native library loading. + /// + public class LlamaLibUnity : UndreamAI.LlamaLib.LlamaLib + { + public LlamaLibUnity(bool gpu = false) : base(gpu) {} - /// - /// Allows to cancel the requests in a specific slot of the LLM - /// - /// slot of the LLM - public void CancelRequest(int id_slot) + public override string FindLibrary(string libraryName) { - AssertStarted(); - llmlib?.LLM_Cancel(LLMObject, id_slot); - CheckLLMStatus(); - } + string lookupDir = Path.Combine(LLMUnitySetup.libraryPath, GetPlatform(), "native"); + string libraryPath = Path.Combine(lookupDir, libraryName); - /// - /// Stops and destroys the LLM - /// - public void Destroy() - { - lock (staticLock) - lock (startLock) - { - try - { - if (llmlib != null) - { - if (LLMObject != IntPtr.Zero) - { - llmlib.LLM_Stop(LLMObject); - if (remote) llmlib.LLM_StopServer(LLMObject); - StopLogging(); - llmThread?.Join(); - llmlib.LLM_Delete(LLMObject); - LLMObject = IntPtr.Zero; - } - llmlib.Destroy(); - llmlib = null; - } - started = false; - failed = false; - } - catch (Exception e) - { - LLMUnitySetup.LogError(e.Message); - } - } - } + if (File.Exists(libraryPath)) + { + return libraryPath; + } - /// - /// The Unity OnDestroy function called when the onbject is destroyed. - /// The function StopProcess is called to stop the LLM server. - /// - public void OnDestroy() - { - Destroy(); - LLMManager.Unregister(this); + throw new FileNotFoundException($"Native library not found: {libraryName} in {lookupDir}"); } } } diff --git a/Runtime/LLMAgent.cs b/Runtime/LLMAgent.cs new file mode 100644 index 00000000..21094ea8 --- /dev/null +++ b/Runtime/LLMAgent.cs @@ -0,0 +1,487 @@ +/// @file +/// @brief File implementing the LLM chat agent functionality for Unity. +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using UndreamAI.LlamaLib; +using UnityEngine; + +namespace LLMUnity +{ + [DefaultExecutionOrder(-1)] + /// @ingroup llm + /// + /// Unity MonoBehaviour that implements a conversational AI agent with persistent chat history. + /// Extends LLMClient to provide chat-specific functionality including role management, + /// conversation history persistence, and specialized chat completion methods. + /// + public class LLMAgent : LLMClient + { + #region Inspector Fields + /// Filename for saving chat history (saved in persistentDataPath) + [Tooltip("Filename for saving chat history (saved in Application.persistentDataPath)")] + [LLM] public string save = ""; + + /// Save LLM processing cache for faster reload (~100MB per agent) + [Tooltip("Save LLM processing cache for faster reload (~100MB per agent)")] + [LLM] public bool saveCache = false; + + /// Server slot to use for processing (affects caching behavior) + [Tooltip("Server slot to use for processing (affects caching behavior)")] + [ModelAdvanced, SerializeField] protected int _slot = -1; + + /// Role name for user messages in conversation + [Tooltip("Role name for user messages in conversation")] + [Chat, SerializeField] protected string _userRole = "user"; + + /// Role name for AI assistant messages in conversation + [Tooltip("Role name for AI assistant messages in conversation")] + [Chat, SerializeField] protected string _assistantRole = "assistant"; + + /// System prompt that defines the AI's personality and behavior + [Tooltip("System prompt that defines the AI's personality and behavior")] + [TextArea(5, 10), Chat, SerializeField] + protected string _systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; + #endregion + + #region Public Properties + /// Server slot ID for this agent's requests + public int slot + { + get => _slot; + set + { + if (_slot != value) + { + _slot = value; + if (llmAgent != null) llmAgent.SlotId = _slot; + } + } + } + + /// Role identifier for user messages + public string userRole + { + get => _userRole; + set + { + if (_userRole != value) + { + _userRole = value; + if (llmAgent != null) llmAgent.UserRole = _userRole; + } + } + } + + /// Role identifier for assistant messages + public string assistantRole + { + get => _assistantRole; + set + { + if (_assistantRole != value) + { + _assistantRole = value; + if (llmAgent != null) llmAgent.AssistantRole = _assistantRole; + } + } + } + + /// System prompt defining the agent's behavior and personality + public string systemPrompt + { + get => _systemPrompt; + set + { + if (_systemPrompt != value) + { + _systemPrompt = value; + if (llmAgent != null) llmAgent.SystemPrompt = _systemPrompt; + } + } + } + + /// The underlying LLMAgent instance from LlamaLib + public UndreamAI.LlamaLib.LLMAgent llmAgent { get; protected set; } + + /// Current conversation history as a list of chat messages + public List chat + { + get => llmAgent?.GetHistory() ?? new List(); + set + { + CheckLLMAgent(); + llmAgent.SetHistory(value ?? new List()); + } + } + #endregion + + #region Unity Lifecycle and Initialization + public override void Awake() + { + if (!remote) llm?.Register(this); + base.Awake(); + } + + private void CheckLLMAgent() + { + if (llmAgent == null) + { + string error = "LLMAgent not initialized"; + LLMUnitySetup.LogError(error); + throw new System.InvalidOperationException(error); + } + } + + protected override async Task SetupLLMClient() + { + await base.SetupLLMClient(); + + string exceptionMessage = ""; + try + { + llmAgent = new UndreamAI.LlamaLib.LLMAgent(llmClient, systemPrompt, userRole, assistantRole); + } + catch (Exception ex) + { + exceptionMessage = ex.Message; + } + if (llmAgent == null || exceptionMessage != "") + { + string error = "LLMAgent not initialized"; + if (exceptionMessage != "") error += ", error: " + exceptionMessage; + LLMUnitySetup.LogError(error); + throw new InvalidOperationException(error); + } + + if (slot != -1) llmAgent.SlotId = slot; + InitHistory(); + } + + protected override void OnValidate() + { + base.OnValidate(); + + // Validate slot configuration + if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) + { + LLMUnitySetup.LogError($"Slot must be between 0 and {llm.parallelPrompts - 1}, or -1 for auto-assignment"); + } + } + + protected override LLMLocal GetCaller() + { + return llmAgent; + } + + /// + /// Initializes conversation history by clearing current state and loading from file if available. + /// + protected virtual void InitHistory() + { + ClearChat(); + LoadHistory(); + } + + /// + /// Loads conversation history from the saved file if it exists. + /// + protected virtual void LoadHistory() + { + if (string.IsNullOrEmpty(save) || !File.Exists(GetJsonSavePath(save))) + { + return; + } + + try + { + Load(save); + } + catch (System.Exception ex) + { + LLMUnitySetup.LogError($"Failed to load chat history from '{save}': {ex.Message}"); + } + } + + #endregion + + #region File Path Management + /// + /// Gets the full path for a file in the persistent data directory. + /// + /// Filename or relative path + /// Full file path in persistent data directory + protected virtual string GetSavePath(string filename) + { + if (string.IsNullOrEmpty(filename)) + { + throw new System.ArgumentNullException(nameof(filename)); + } + + return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/'); + } + + /// + /// Gets the save path for chat history JSON file. + /// + /// Base filename (without extension) + /// Full path to JSON file + public virtual string GetJsonSavePath(string filename) + { + return GetSavePath(filename + ".json"); + } + + /// + /// Gets the save path for LLM cache file. + /// + /// Base filename (without extension) + /// Full path to cache file + public virtual string GetCacheSavePath(string filename) + { + return GetSavePath(filename + ".cache"); + } + + #endregion + + #region Chat Management + /// + /// Clears the entire conversation history. + /// + public virtual void ClearChat() + { + CheckLLMAgent(); + llmAgent.ClearHistory(); + } + + /// + /// Adds a message with a specific role to the conversation history. + /// + /// Message role (e.g., userRole, assistantRole, or custom role) + /// Message content + public virtual void AddMessage(string role, string content) + { + CheckLLMAgent(); + llmAgent.AddMessage(role, content); + } + + /// + /// Adds a user message to the conversation history. + /// + /// User message content + public virtual void AddUserMessage(string content) + { + CheckLLMAgent(); + llmAgent.AddUserMessage(content); + } + + /// + /// Adds an AI assistant message to the conversation history. + /// + /// Assistant message content + public virtual void AddAssistantMessage(string content) + { + CheckLLMAgent(); + llmAgent.AddAssistantMessage(content); + } + + #endregion + + #region Chat Functionality + /// + /// Processes a user query and generates an AI response using conversation context. + /// The query and response are automatically added to chat history if specified. + /// + /// User's message or question + /// Optional streaming callback for partial responses + /// Whether to add the exchange to conversation history + /// AI assistant's response + public virtual string Chat(string query, LlamaLib.CharArrayCallback callback = null, bool addToHistory = true) + { + CheckLLMAgent(); + SetCompletionParameters(); + return llmAgent.Chat(query, addToHistory, callback); + } + + /// + /// Processes a user query asynchronously and generates an AI response using conversation context. + /// The query and response are automatically added to chat history if specified. + /// + /// User's message or question + /// Optional streaming callback for partial responses + /// Optional callback when response is complete + /// Whether to add the exchange to conversation history + /// Task that returns the AI assistant's response + public virtual async Task ChatAsync(string query, LlamaLib.CharArrayCallback callback = null, + EmptyCallback completionCallback = null, bool addToHistory = true) + { + CheckLLMAgent(); + + // Wrap callback to ensure it runs on the main thread + LlamaLib.CharArrayCallback wrappedCallback = Utils.WrapCallbackForAsync(callback); + SetCompletionParameters(); + string result = await llmAgent.ChatAsync(query, addToHistory, wrappedCallback); + completionCallback?.Invoke(); + return result; + } + + #endregion + + #region Model Warmup + /// + /// Warms up the model by processing the system prompt without generating output. + /// This caches the system prompt processing for faster subsequent responses. + /// + /// Optional callback when warmup completes + /// Task that completes when warmup finishes + public virtual async Task Warmup(EmptyCallback completionCallback = null) + { + await Warmup(null, completionCallback); + } + + /// + /// Warms up the model with a specific prompt without adding it to history. + /// This pre-processes prompts for faster response times in subsequent interactions. + /// + /// Warmup prompt (not added to history) + /// Optional callback when warmup completes + /// Task that completes when warmup finishes + public virtual async Task Warmup(string query, EmptyCallback completionCallback = null) + { + int originalNumPredict = numPredict; + try + { + // Set to generate no tokens for warmup + numPredict = 0; + await ChatAsync(query, null, completionCallback, false); + } + finally + { + // Restore original setting + numPredict = originalNumPredict; + SetCompletionParameters(); + } + } + + #endregion + + #region Persistence + /// + /// Saves the conversation history and optionally the LLM cache to disk. + /// + /// Base filename (without extension) for saving + /// Result message from cache save operation, or null if cache not saved + public virtual string Save(string filename) + { + if (string.IsNullOrEmpty(filename)) + { + throw new System.ArgumentNullException(nameof(filename)); + } + CheckLLMAgent(); + + // Save chat history + string jsonPath = GetJsonSavePath(filename); + string directory = Path.GetDirectoryName(jsonPath); + + if (!Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } + + try + { + llmAgent.SaveHistory(jsonPath); + LLMUnitySetup.Log($"Saved chat history to: {jsonPath}"); + } + catch (System.Exception ex) + { + LLMUnitySetup.LogError($"Failed to save chat history to '{jsonPath}': {ex.Message}"); + throw; + } + + // Save cache if enabled and not remote + if (!remote && saveCache) + { + try + { + string cachePath = GetCacheSavePath(filename); + string result = llmAgent.SaveSlot(cachePath); + LLMUnitySetup.Log($"Saved LLM cache to: {cachePath}"); + return result; + } + catch (System.Exception ex) + { + LLMUnitySetup.LogWarning($"Failed to save LLM cache: {ex.Message}"); + } + } + + return null; + } + + /// + /// Loads conversation history and optionally the LLM cache from disk. + /// + /// Base filename (without extension) to load from + /// Result message from cache load operation, or null if cache not loaded + public virtual string Load(string filename) + { + if (string.IsNullOrEmpty(filename)) + { + throw new System.ArgumentNullException(nameof(filename)); + } + CheckLLMAgent(); + + // Load chat history + string jsonPath = GetJsonSavePath(filename); + if (!File.Exists(jsonPath)) + { + throw new FileNotFoundException($"Chat history file not found: {jsonPath}"); + } + + try + { + llmAgent.LoadHistory(jsonPath); + LLMUnitySetup.Log($"Loaded chat history from: {jsonPath}"); + } + catch (System.Exception ex) + { + LLMUnitySetup.LogError($"Failed to load chat history from '{jsonPath}': {ex.Message}"); + throw; + } + + // Load cache if enabled and not remote + if (!remote && saveCache) + { + string cachePath = GetCacheSavePath(filename); + if (File.Exists(cachePath)) + { + try + { + string result = llmAgent.LoadSlot(cachePath); + LLMUnitySetup.Log($"Loaded LLM cache from: {cachePath}"); + return result; + } + catch (System.Exception ex) + { + LLMUnitySetup.LogWarning($"Failed to load LLM cache from '{cachePath}': {ex.Message}"); + } + } + } + + return null; + } + + #endregion + + #region Request Management + /// + /// Cancels any active requests for this agent. + /// + public void CancelRequests() + { + llmAgent?.Cancel(); + } + + #endregion + } +} diff --git a/Editor/LLMCallerEditor.cs.meta b/Runtime/LLMAgent.cs.meta similarity index 83% rename from Editor/LLMCallerEditor.cs.meta rename to Runtime/LLMAgent.cs.meta index 113bed10..2be83c16 100644 --- a/Editor/LLMCallerEditor.cs.meta +++ b/Runtime/LLMAgent.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 7d6ebb9c97f8c2e959bbdb536ee13ec5 +guid: b4326d5ae3b03ff55847035351559f4e MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs index 6725cd43..9d595c2a 100644 --- a/Runtime/LLMBuilder.cs +++ b/Runtime/LLMBuilder.cs @@ -98,7 +98,7 @@ public static void MovePath(string source, string target) /// path public static bool DeletePath(string path) { - string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, PluginDir("Android"), PluginDir("iOS"), PluginDir("VisionOS")}; + string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, PluginDir("Android"), PluginDir("iOS"), PluginDir("VisionOS") }; bool deleteOK = false; foreach (string allowedDir in allowedDirs) deleteOK = deleteOK || LLMUnitySetup.IsSubPath(path, allowedDir); if (!deleteOK) @@ -113,7 +113,7 @@ public static bool DeletePath(string path) static void AddMovedPair(string source, string target) { - movedPairs.Add(new StringPair {source = source, target = target}); + movedPairs.Add(new StringPair { source = source, target = target }); File.WriteAllText(movedCache, JsonUtility.ToJson(new ListStringPair { pairs = movedPairs }, true)); } @@ -164,37 +164,38 @@ static void AddActionAddMeta(string target) /// target platform public static void BuildLibraryPlatforms(BuildTarget buildTarget) { - string platform = ""; + List platforms = new List(); + bool checkCUBLAS = false; switch (buildTarget) { case BuildTarget.StandaloneWindows: case BuildTarget.StandaloneWindows64: - platform = "windows"; + platforms.Add("win-x64"); + checkCUBLAS = true; break; case BuildTarget.StandaloneLinux64: - platform = "linux"; + platforms.Add("linux-x64"); + checkCUBLAS = true; break; case BuildTarget.StandaloneOSX: - platform = "macos"; + platforms.Add("osx-universal"); break; case BuildTarget.Android: - platform = "android"; + platforms.Add("android-arm64"); + platforms.Add("android-x64"); break; case BuildTarget.iOS: - platform = "ios"; + platforms.Add("ios-arm64"); break; case BuildTarget.VisionOS: - platform = "visionos"; + platforms.Add("visionos-arm64"); break; } foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath)) { string sourceName = Path.GetFileName(source); - bool move = !sourceName.StartsWith(platform); - move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib); - move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib); - if (move) + if (!platforms.Contains(sourceName)) { string target = Path.Combine(BuildTempDir, sourceName); MoveAction(source, target); @@ -202,32 +203,65 @@ public static void BuildLibraryPlatforms(BuildTarget buildTarget) } } + if (checkCUBLAS) + { + List exclusionKeywords = LLMUnitySetup.CUBLAS ? new List() { "tinyblas" } : new List() { "cublas", "cudart" }; + foreach (string platform in platforms) + { + string platformDir = Path.Combine(LLMUnitySetup.libraryPath, platform, "native"); + foreach (string source in Directory.GetFiles(platformDir)) + { + string sourceName = Path.GetFileName(source); + foreach (string exclusionKeyword in exclusionKeywords) + { + if (sourceName.Contains(exclusionKeyword)) + { + string target = Path.Combine(BuildTempDir, platform, "native", sourceName); + MoveAction(source, target); + MoveAction(source + ".meta", target + ".meta"); + break; + } + } + } + } + } + if (buildTarget == BuildTarget.Android || buildTarget == BuildTarget.iOS || buildTarget == BuildTarget.VisionOS) { - string source = Path.Combine(LLMUnitySetup.libraryPath, platform); - string target = PluginLibraryDir(buildTarget.ToString()); - string pluginDir = PluginDir(buildTarget.ToString()); - MoveAction(source, target); - MoveAction(source + ".meta", target + ".meta"); - AddActionAddMeta(pluginDir); + foreach (string platform in platforms) + { + string source = Path.Combine(LLMUnitySetup.libraryPath, platform, "native"); + string target = Path.Combine(PluginLibraryDir(buildTarget.ToString()), platform); + string pluginDir = PluginDir(buildTarget.ToString()); + MoveAction(source, target); + MoveAction(source + ".meta", target + ".meta"); + AddActionAddMeta(pluginDir); + } } } static void OnPostprocessAllAssets(string[] importedAssets, string[] deletedAssets, string[] movedAssets, string[] movedFromAssetPaths, bool didDomainReload) { - foreach (BuildTarget buildTarget in new BuildTarget[]{BuildTarget.iOS, BuildTarget.VisionOS}) + foreach (BuildTarget buildTarget in new BuildTarget[] { BuildTarget.iOS, BuildTarget.VisionOS, BuildTarget.Android }) { - string pathToPlugin = Path.Combine("Assets", PluginLibraryDir(buildTarget.ToString(), true), $"libundreamai_{buildTarget.ToString().ToLower()}.a"); - for (int i = 0; i < movedAssets.Length; i++) + string suffix = (buildTarget == BuildTarget.Android) ? "so" : "a"; + string platformDir = Path.Combine("Assets", PluginLibraryDir(buildTarget.ToString(), true)); + if (!Directory.Exists(platformDir)) continue; + foreach (string archDir in Directory.GetDirectories(platformDir)) { - if (movedAssets[i] == pathToPlugin) + string arch = Path.GetFileName(archDir); + string pathToPlugin = Path.Combine(platformDir, $"libllamalib_{arch}.{suffix}"); + for (int i = 0; i < movedAssets.Length; i++) { - var importer = AssetImporter.GetAtPath(pathToPlugin) as PluginImporter; - if (importer != null && importer.isNativePlugin) + if (movedAssets[i] == pathToPlugin) { - importer.SetCompatibleWithPlatform(buildTarget, true); - importer.SetPlatformData(buildTarget, "CPU", "ARM64"); - AssetDatabase.ImportAsset(pathToPlugin); + var importer = AssetImporter.GetAtPath(pathToPlugin) as PluginImporter; + if (importer != null && importer.isNativePlugin) + { + importer.SetCompatibleWithPlatform(buildTarget, true); + importer.SetPlatformData(buildTarget, "CPU", arch.Split("_")[1].ToUpper()); + AssetDatabase.ImportAsset(pathToPlugin); + } } } } diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs deleted file mode 100644 index 4b69137d..00000000 --- a/Runtime/LLMCaller.cs +++ /dev/null @@ -1,387 +0,0 @@ -/// @file -/// @brief File implementing the basic functionality for LLM callers. -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using UnityEngine; -using UnityEngine.Networking; - -namespace LLMUnity -{ - [DefaultExecutionOrder(-2)] - /// @ingroup llm - /// - /// Class implementing calling of LLM functions (local and remote). - /// - public class LLMCaller : MonoBehaviour - { - /// show/hide advanced options in the GameObject - [Tooltip("show/hide advanced options in the GameObject")] - [HideInInspector] public bool advancedOptions = false; - /// use remote LLM server - [Tooltip("use remote LLM server")] - [LocalRemote] public bool remote = false; - /// LLM GameObject to use - [Tooltip("LLM GameObject to use")] // Tooltip: ignore - [Local, SerializeField] protected LLM _llm; - public LLM llm - { - get => _llm;//whatever - set => SetLLM(value); - } - /// API key for the remote server - [Tooltip("API key for the remote server")] - [Remote] public string APIKey; - /// host of the remote LLM server - [Tooltip("host of the remote LLM server")] - [Remote] public string host = "localhost"; - /// port of the remote LLM server - [Tooltip("port of the remote LLM server")] - [Remote] public int port = 13333; - /// number of retries to use for the remote LLM server requests (-1 = infinite) - [Tooltip("number of retries to use for the remote LLM server requests (-1 = infinite)")] - [Remote] public int numRetries = 10; - - protected LLM _prellm; - protected List<(string, string)> requestHeaders; - protected List WIPRequests = new List(); - - /// - /// The Unity Awake function that initializes the state before the application starts. - /// The following actions are executed: - /// - the corresponding LLM server is defined (if ran locally) - /// - the grammar is set based on the grammar file - /// - the prompt and chat history are initialised - /// - the chat template is constructed - /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true) - /// - public virtual void Awake() - { - // Start the LLM server in a cross-platform way - if (!enabled) return; - - requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; - if (!remote) - { - AssignLLM(); - if (llm == null) - { - string error = $"No LLM assigned or detected for LLMCharacter {name}!"; - LLMUnitySetup.LogError(error); - throw new Exception(error); - } - } - else - { - if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey)); - } - } - - /// - /// Sets the LLM object of the LLMCaller - /// - /// LLM object - protected virtual void SetLLM(LLM llmSet) - { - if (llmSet != null && !IsValidLLM(llmSet)) - { - LLMUnitySetup.LogError(NotValidLLMError()); - llmSet = null; - } - _llm = llmSet; - _prellm = _llm; - } - - /// - /// Checks if a LLM is valid for the LLMCaller - /// - /// LLM object - /// bool specifying whether the LLM is valid - public virtual bool IsValidLLM(LLM llmSet) - { - return true; - } - - /// - /// Checks if a LLM can be auto-assigned if the LLM of the LLMCaller is null - /// - /// - /// bool specifying whether the LLM can be auto-assigned - public virtual bool IsAutoAssignableLLM(LLM llmSet) - { - return true; - } - - protected virtual string NotValidLLMError() - { - return $"Can't set LLM {llm.name} to {name}"; - } - - protected virtual void OnValidate() - { - if (_llm != _prellm) SetLLM(_llm); - AssignLLM(); - } - - protected virtual void Reset() - { - AssignLLM(); - } - - protected virtual void AssignLLM() - { - if (remote || llm != null) return; - - List validLLMs = new List(); -#if UNITY_6000_0_OR_NEWER - foreach (LLM foundllm in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None)) -#else - foreach (LLM foundllm in FindObjectsOfType()) -#endif - { - if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm); - } - if (validLLMs.Count == 0) return; - - llm = SortLLMsByBestMatching(validLLMs.ToArray())[0]; - string msg = $"Assigning LLM {llm.name} to {GetType()} {name}"; - if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}"; - LLMUnitySetup.Log(msg); - } - - protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn) - { - LLM[] array = (LLM[])arrayIn.Clone(); - for (int i = 0; i < array.Length - 1; i++) - { - bool swapped = false; - for (int j = 0; j < array.Length - i - 1; j++) - { - bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene; - bool swap = ( - (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) || - (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex()) - ); - if (swap) - { - LLM temp = array[j]; - array[j] = array[j + 1]; - array[j + 1] = temp; - swapped = true; - } - } - if (!swapped) break; - } - return array; - } - - protected virtual List TokenizeContent(TokenizeResult result) - { - // get the tokens from a tokenize result received from the endpoint - return result.tokens; - } - - protected virtual string DetokenizeContent(TokenizeRequest result) - { - // get content from a chat result received from the endpoint - return result.content; - } - - protected virtual List EmbeddingsContent(EmbeddingsResult result) - { - // get content from a chat result received from the endpoint - return result.embedding; - } - - protected virtual Ret ConvertContent(string response, ContentCallback getContent = null) - { - // template function to convert the json received and get the content - if (response == null) return default; - response = response.Trim(); - if (response.StartsWith("data: ")) - { - string responseArray = ""; - foreach (string responsePart in response.Replace("\n\n", "").Split("data: ")) - { - if (responsePart == "") continue; - if (responseArray != "") responseArray += ",\n"; - responseArray += responsePart; - } - response = $"{{\"data\": [{responseArray}]}}"; - } - return getContent(JsonUtility.FromJson(response)); - } - - protected virtual void CancelRequestsLocal() {} - - protected virtual void CancelRequestsRemote() - { - foreach (UnityWebRequest request in WIPRequests) - { - request.Abort(); - } - WIPRequests.Clear(); - } - - /// - /// Cancel the ongoing requests e.g. Chat, Complete. - /// - // - public virtual void CancelRequests() - { - if (remote) CancelRequestsRemote(); - else CancelRequestsLocal(); - } - - protected virtual async Task PostRequestLocal(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - // send a post request to the server and call the relevant callbacks to convert the received content and handle it - // this function has streaming functionality i.e. handles the answer while it is being received - while (!llm.failed && !llm.started) await Task.Yield(); - string callResult = null; - switch (endpoint) - { - case "tokenize": - callResult = await llm.Tokenize(json); - break; - case "detokenize": - callResult = await llm.Detokenize(json); - break; - case "embeddings": - callResult = await llm.Embeddings(json); - break; - case "slots": - callResult = await llm.Slot(json); - break; - default: - LLMUnitySetup.LogError($"Unknown endpoint {endpoint}"); - break; - } - - Ret result = ConvertContent(callResult, getContent); - callback?.Invoke(result); - return result; - } - - protected virtual async Task PostRequestRemote(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - // send a post request to the server and call the relevant callbacks to convert the received content and handle it - // this function has streaming functionality i.e. handles the answer while it is being received - if (endpoint == "slots") - { - LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting"); - return default; - } - - Ret result = default; - byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json); - UnityWebRequest request = null; - string error = null; - int tryNr = numRetries; - - while (tryNr != 0) - { - using (request = UnityWebRequest.Put($"{host}{(port != 0 ? $":{port}" : "")}/{endpoint}", jsonToSend)) - { - WIPRequests.Add(request); - - request.method = "POST"; - if (requestHeaders != null) - { - for (int i = 0; i < requestHeaders.Count; i++) - request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2); - } - - // Start the request asynchronously - UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest(); - await Task.Yield(); // Wait for the next frame so that asyncOperation is properly registered (especially if not in main thread) - - float lastProgress = 0f; - // Continue updating progress until the request is completed - while (!asyncOperation.isDone) - { - float currentProgress = request.downloadProgress; - // Check if progress has changed - if (currentProgress != lastProgress && callback != null) - { - callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent)); - lastProgress = currentProgress; - } - // Wait for the next frame - await Task.Yield(); - } - WIPRequests.Remove(request); - if (request.result == UnityWebRequest.Result.Success) - { - result = ConvertContent(request.downloadHandler.text, getContent); - error = null; - break; - } - else - { - result = default; - error = request.error; - if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break; - } - } - tryNr--; - if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr)); - } - - if (error != null) LLMUnitySetup.LogError(error); - callback?.Invoke(result); - return result; - } - - protected virtual async Task PostRequest(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - if (remote) return await PostRequestRemote(json, endpoint, getContent, callback); - return await PostRequestLocal(json, endpoint, getContent, callback); - } - - /// - /// Tokenises the provided query. - /// - /// query to tokenise - /// callback function called with the result tokens - /// list of the tokens - public virtual async Task> Tokenize(string query, Callback> callback = null) - { - // handle the tokenization of a message by the user - TokenizeRequest tokenizeRequest = new TokenizeRequest(); - tokenizeRequest.content = query; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest>(json, "tokenize", TokenizeContent, callback); - } - - /// - /// Detokenises the provided tokens to a string. - /// - /// tokens to detokenise - /// callback function called with the result string - /// the detokenised string - public virtual async Task Detokenize(List tokens, Callback callback = null) - { - // handle the detokenization of a message by the user - TokenizeResult tokenizeRequest = new TokenizeResult(); - tokenizeRequest.tokens = tokens; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest(json, "detokenize", DetokenizeContent, callback); - } - - /// - /// Computes the embeddings of the provided input. - /// - /// input to compute the embeddings for - /// callback function called with the result string - /// the computed embeddings - public virtual async Task> Embeddings(string query, Callback> callback = null) - { - // handle the tokenization of a message by the user - TokenizeRequest tokenizeRequest = new TokenizeRequest(); - tokenizeRequest.content = query; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); - } - } -} diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs deleted file mode 100644 index 0bf91a15..00000000 --- a/Runtime/LLMCharacter.cs +++ /dev/null @@ -1,726 +0,0 @@ -/// @file -/// @brief File implementing the LLM characters. -using System; -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using UnityEditor; -using UnityEngine; - -namespace LLMUnity -{ - [DefaultExecutionOrder(-2)] - /// @ingroup llm - /// - /// Class implementing the LLM characters. - /// - public class LLMCharacter : LLMCaller - { - /// file to save the chat history. - /// The file will be saved within the persistentDataPath directory. - [Tooltip("file to save the chat history. The file will be saved within the persistentDataPath directory.")] - [LLM] public string save = ""; - /// save the LLM cache. Speeds up the prompt calculation when reloading from history but also requires ~100MB of space per character. - [Tooltip("save the LLM cache. Speeds up the prompt calculation when reloading from history but also requires ~100MB of space per character.")] - [LLM] public bool saveCache = false; - /// log the constructed prompt the Unity Editor. - [Tooltip("log the constructed prompt the Unity Editor.")] - [LLM] public bool debugPrompt = false; - /// maximum number of tokens that the LLM will predict (-1 = infinity). - [Tooltip("maximum number of tokens that the LLM will predict (-1 = infinity).")] - [Model] public int numPredict = -1; - /// slot of the server to use for computation (affects caching) - [Tooltip("slot of the server to use for computation (affects caching)")] - [ModelAdvanced] public int slot = -1; - /// grammar file used for the LLMCharacter (.gbnf format) - [Tooltip("grammar file used for the LLMCharacter (.gbnf format)")] - [ModelAdvanced] public string grammar = null; - /// grammar file used for the LLMCharacter (.json format) - [Tooltip("grammar file used for the LLMCharacter (.json format)")] - [ModelAdvanced] public string grammarJSON = null; - /// cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!) - [Tooltip("cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)")] - [ModelAdvanced] public bool cachePrompt = true; - /// seed for reproducibility (-1 = no reproducibility). - [Tooltip("seed for reproducibility (-1 = no reproducibility).")] - [ModelAdvanced] public int seed = 0; - /// LLM temperature, lower values give more deterministic answers. - [Tooltip("LLM temperature, lower values give more deterministic answers.")] - [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f; - /// Top-k sampling selects the next token only from the top k most likely predicted tokens (0 = disabled). - /// Higher values lead to more diverse text, while lower value will generate more focused and conservative text. - /// - [Tooltip("Top-k sampling selects the next token only from the top k most likely predicted tokens (0 = disabled). Higher values lead to more diverse text, while lower value will generate more focused and conservative text. ")] - [ModelAdvanced, Int(-1, 100)] public int topK = 40; - /// Top-p sampling selects the next token from a subset of tokens that together have a cumulative probability of at least p (1.0 = disabled). - /// Higher values lead to more diverse text, while lower value will generate more focused and conservative text. - /// - [Tooltip("Top-p sampling selects the next token from a subset of tokens that together have a cumulative probability of at least p (1.0 = disabled). Higher values lead to more diverse text, while lower value will generate more focused and conservative text. ")] - [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f; - /// minimum probability for a token to be used. - [Tooltip("minimum probability for a token to be used.")] - [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f; - /// Penalty based on repeated tokens to control the repetition of token sequences in the generated text. - [Tooltip("Penalty based on repeated tokens to control the repetition of token sequences in the generated text.")] - [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f; - /// Penalty based on token presence in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled). - [Tooltip("Penalty based on token presence in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled).")] - [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f; - /// Penalty based on token frequency in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled). - [Tooltip("Penalty based on token frequency in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled).")] - [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f; - /// enable locally typical sampling (1.0 = disabled). Higher values will promote more contextually coherent tokens, while lower values will promote more diverse tokens. - [Tooltip("enable locally typical sampling (1.0 = disabled). Higher values will promote more contextually coherent tokens, while lower values will promote more diverse tokens.")] - [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f; - /// last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size). - [Tooltip("last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).")] - [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64; - /// penalize newline tokens when applying the repeat penalty. - [Tooltip("penalize newline tokens when applying the repeat penalty.")] - [ModelAdvanced] public bool penalizeNl = true; - /// prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers representing tokens (null/'' = use original prompt) - [Tooltip("prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers representing tokens (null/'' = use original prompt)")] - [ModelAdvanced] public string penaltyPrompt; - /// enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). - [Tooltip("enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).")] - [ModelAdvanced, Int(0, 2)] public int mirostat = 0; - /// The Mirostat target entropy (tau) controls the balance between coherence and diversity in the generated text. - [Tooltip("The Mirostat target entropy (tau) controls the balance between coherence and diversity in the generated text.")] - [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f; - /// The Mirostat learning rate (eta) controls how quickly the algorithm responds to feedback from the generated text. - [Tooltip("The Mirostat learning rate (eta) controls how quickly the algorithm responds to feedback from the generated text.")] - [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f; - /// if greater than 0, the response also contains the probabilities of top N tokens for each generated token. - [Tooltip("if greater than 0, the response also contains the probabilities of top N tokens for each generated token.")] - [ModelAdvanced, Int(0, 10)] public int nProbs = 0; - /// ignore end of stream token and continue generating. - [Tooltip("ignore end of stream token and continue generating.")] - [ModelAdvanced] public bool ignoreEos = false; - /// number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prompt tokens if setNKeepToPrompt is set to true). - [Tooltip("number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prompt tokens if setNKeepToPrompt is set to true).")] - public int nKeep = -1; - /// stopwords to stop the LLM in addition to the default stopwords from the chat template. - [Tooltip("stopwords to stop the LLM in addition to the default stopwords from the chat template.")] - public List stop = new List(); - /// the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the generated text. - /// By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated. - [Tooltip("the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the generated text. By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated.")] - public Dictionary logitBias = null; - /// Receive the reply from the model as it is produced (recommended!). - /// If not selected, the full reply from the model is received in one go - [Tooltip("Receive the reply from the model as it is produced (recommended!). If not selected, the full reply from the model is received in one go")] - [Chat] public bool stream = true; - /// the name of the player - [Tooltip("the name of the player")] - [Chat] public string playerName = "user"; - /// the name of the AI - [Tooltip("the name of the AI")] - [Chat] public string AIName = "assistant"; - /// a description of the AI role (system prompt) - [Tooltip("a description of the AI role (system prompt)")] - [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; - /// set the number of tokens to always retain from the prompt (nKeep) based on the LLMCharacter system prompt - [Tooltip("set the number of tokens to always retain from the prompt (nKeep) based on the LLMCharacter system prompt")] - public bool setNKeepToPrompt = true; - /// the chat history as list of chat messages - [Tooltip("the chat history as list of chat messages")] - public List chat = new List(); - /// the grammar to use - [Tooltip("the grammar to use")] - public string grammarString; - /// the grammar to use - [Tooltip("the grammar to use")] - public string grammarJSONString; - - /// \cond HIDE - protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1); - protected string chatTemplate; - protected ChatTemplate template = null; - /// \endcond - - /// - /// The Unity Awake function that initializes the state before the application starts. - /// The following actions are executed: - /// - the corresponding LLM server is defined (if ran locally) - /// - the grammar is set based on the grammar file - /// - the prompt and chat history are initialised - /// - the chat template is constructed - /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true) - /// - public override void Awake() - { - if (!enabled) return; - base.Awake(); - if (!remote) - { - int slotFromServer = llm.Register(this); - if (slot == -1) slot = slotFromServer; - } - InitGrammar(); - InitHistory(); - } - - protected override void OnValidate() - { - base.OnValidate(); - if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set"); - } - - protected override string NotValidLLMError() - { - return base.NotValidLLMError() + $", it is an embedding only model"; - } - - /// - /// Checks if a LLM is valid for the LLMCaller - /// - /// LLM object - /// bool specifying whether the LLM is valid - public override bool IsValidLLM(LLM llmSet) - { - return !llmSet.embeddingsOnly; - } - - protected virtual void InitHistory() - { - ClearChat(); - _ = LoadHistory(); - } - - protected virtual async Task LoadHistory() - { - if (save == "" || !File.Exists(GetJsonSavePath(save))) return; - await chatLock.WaitAsync(); // Acquire the lock - try - { - await Load(save); - } - finally - { - chatLock.Release(); // Release the lock - } - } - - protected virtual string GetSavePath(string filename) - { - return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/'); - } - - /// - /// Allows to get the save path of the chat history based on the provided filename or relative path. - /// - /// filename or relative path used for the save - /// save path - public virtual string GetJsonSavePath(string filename) - { - return GetSavePath(filename + ".json"); - } - - /// - /// Allows to get the save path of the LLM cache based on the provided filename or relative path. - /// - /// filename or relative path used for the save - /// save path - public virtual string GetCacheSavePath(string filename) - { - return GetSavePath(filename + ".cache"); - } - - /// - /// Clear the chat of the LLMCharacter. - /// - public virtual void ClearChat() - { - chat.Clear(); - ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt }; - chat.Add(promptMessage); - } - - /// - /// Set the system prompt for the LLMCharacter. - /// - /// the system prompt - /// whether to clear (true) or keep (false) the current chat history on top of the system prompt. - public virtual void SetPrompt(string newPrompt, bool clearChat = true) - { - prompt = newPrompt; - nKeep = -1; - if (clearChat) ClearChat(); - else chat[0] = new ChatMessage { role = "system", content = prompt }; - } - - protected virtual bool CheckTemplate() - { - if (template == null) - { - LLMUnitySetup.LogError("Template not set!"); - return false; - } - return true; - } - - protected virtual async Task InitNKeep() - { - if (setNKeepToPrompt && nKeep == -1) - { - if (!CheckTemplate()) return false; - string systemPrompt = template.ComputePrompt(new List(){chat[0]}, playerName, "", false); - List tokens = await Tokenize(systemPrompt); - if (tokens == null) return false; - SetNKeep(tokens); - } - return true; - } - - protected virtual void InitGrammar() - { - grammarString = ""; - grammarJSONString = ""; - if (!String.IsNullOrEmpty(grammar)) - { - grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar)); - if (!String.IsNullOrEmpty(grammarJSON)) - LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used"); - } - else if (!String.IsNullOrEmpty(grammarJSON)) - { - grammarJSONString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammarJSON)); - } - } - - protected virtual void SetNKeep(List tokens) - { - // set the tokens to keep - nKeep = tokens.Count; - } - - /// - /// Loads the chat template of the LLMCharacter. - /// - /// - public virtual async Task LoadTemplate() - { - string llmTemplate; - if (remote) - { - llmTemplate = await AskTemplate(); - } - else - { - llmTemplate = llm.GetTemplate(); - } - if (llmTemplate != chatTemplate) - { - chatTemplate = llmTemplate; - template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate); - nKeep = -1; - } - } - - /// - /// Sets the grammar file of the LLMCharacter - /// - /// path to the grammar file - public virtual async Task SetGrammarFile(string path, bool gnbf) - { -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path); -#endif - await LLMUnitySetup.AndroidExtractAsset(path, true); - if (gnbf) grammar = path; - else grammarJSON = path; - InitGrammar(); - } - - /// - /// Sets the grammar file of the LLMCharacter (GBNF) - /// - /// path to the grammar file - public virtual async Task SetGrammar(string path) - { - await SetGrammarFile(path, true); - } - - /// - /// Sets the grammar file of the LLMCharacter (JSON schema) - /// - /// path to the grammar file - public virtual async Task SetJSONGrammar(string path) - { - await SetGrammarFile(path, false); - } - - protected virtual List GetStopwords() - { - if (!CheckTemplate()) return null; - List stopAll = new List(template.GetStop(playerName, AIName)); - if (stop != null) stopAll.AddRange(stop); - return stopAll; - } - - protected virtual ChatRequest GenerateRequest(string prompt) - { - // setup the request struct - ChatRequest chatRequest = new ChatRequest(); - if (debugPrompt) LLMUnitySetup.Log(prompt); - chatRequest.prompt = prompt; - chatRequest.id_slot = slot; - chatRequest.temperature = temperature; - chatRequest.top_k = topK; - chatRequest.top_p = topP; - chatRequest.min_p = minP; - chatRequest.n_predict = numPredict; - chatRequest.n_keep = nKeep; - chatRequest.stream = stream; - chatRequest.stop = GetStopwords(); - chatRequest.typical_p = typicalP; - chatRequest.repeat_penalty = repeatPenalty; - chatRequest.repeat_last_n = repeatLastN; - chatRequest.penalize_nl = penalizeNl; - chatRequest.presence_penalty = presencePenalty; - chatRequest.frequency_penalty = frequencyPenalty; - chatRequest.penalty_prompt = (penaltyPrompt != null && penaltyPrompt != "") ? penaltyPrompt : null; - chatRequest.mirostat = mirostat; - chatRequest.mirostat_tau = mirostatTau; - chatRequest.mirostat_eta = mirostatEta; - chatRequest.grammar = grammarString; - chatRequest.json_schema = grammarJSONString; - chatRequest.seed = seed; - chatRequest.ignore_eos = ignoreEos; - chatRequest.logit_bias = logitBias; - chatRequest.n_probs = nProbs; - chatRequest.cache_prompt = cachePrompt; - return chatRequest; - } - - /// - /// Allows to add a message in the chat history. - /// - /// message role (e.g. playerName or AIName) - /// message content - public virtual void AddMessage(string role, string content) - { - // add the question / answer to the chat list, update prompt - chat.Add(new ChatMessage { role = role, content = content }); - } - - /// - /// Allows to add a player message in the chat history. - /// - /// message content - public virtual void AddPlayerMessage(string content) - { - AddMessage(playerName, content); - } - - /// - /// Allows to add a AI message in the chat history. - /// - /// message content - public virtual void AddAIMessage(string content) - { - AddMessage(AIName, content); - } - - protected virtual string ChatContent(ChatResult result) - { - // get content from a chat result received from the endpoint - return result.content.Trim(); - } - - protected virtual string MultiChatContent(MultiChatResult result) - { - // get content from a chat result received from the endpoint - string response = ""; - foreach (ChatResult resultPart in result.data) - { - response += resultPart.content; - } - return response.Trim(); - } - - protected virtual string SlotContent(SlotResult result) - { - // get the tokens from a tokenize result received from the endpoint - return result.filename; - } - - protected virtual string TemplateContent(TemplateResult result) - { - // get content from a char result received from the endpoint in open AI format - return result.template; - } - - protected virtual string ChatRequestToJson(ChatRequest request) - { - string json = JsonUtility.ToJson(request); - int grammarIndex = json.LastIndexOf('}'); - if (!String.IsNullOrEmpty(request.grammar)) - { - GrammarWrapper grammarWrapper = new GrammarWrapper { grammar = request.grammar }; - string grammarToJSON = JsonUtility.ToJson(grammarWrapper); - int start = grammarToJSON.IndexOf(":\"") + 2; - int end = grammarToJSON.LastIndexOf("\""); - string grammarSerialised = grammarToJSON.Substring(start, end - start); - json = json.Insert(grammarIndex, $",\"grammar\": \"{grammarSerialised}\""); - } - else if (!String.IsNullOrEmpty(request.json_schema)) - { - json = json.Insert(grammarIndex, $",\"json_schema\":{request.json_schema}"); - } - return json; - } - - protected virtual async Task CompletionRequest(ChatRequest request, Callback callback = null) - { - string json = ChatRequestToJson(request); - string result = ""; - if (stream) - { - result = await PostRequest(json, "completion", MultiChatContent, callback); - } - else - { - result = await PostRequest(json, "completion", ChatContent, callback); - } - return result; - } - - protected async Task PromptWithQuery(string query) - { - ChatRequest result = default; - await chatLock.WaitAsync(); - try - { - AddPlayerMessage(query); - string prompt = template.ComputePrompt(chat, playerName, AIName); - result = GenerateRequest(prompt); - chat.RemoveAt(chat.Count - 1); - } - finally - { - chatLock.Release(); - } - return result; - } - - /// - /// Chat functionality of the LLM. - /// It calls the LLM completion based on the provided query including the previous chat history. - /// The function allows callbacks when the response is partially or fully received. - /// The question is added to the history if specified. - /// - /// user query - /// callback function that receives the response as string - /// callback function called when the full response has been received - /// whether to add the user query to the chat history - /// the LLM response - public virtual async Task Chat(string query, Callback callback = null, EmptyCallback completionCallback = null, bool addToHistory = true) - { - // handle a chat message by the user - // call the callback function while the answer is received - // call the completionCallback function when the answer is fully received - await LoadTemplate(); - if (!CheckTemplate()) return null; - if (!await InitNKeep()) return null; - - ChatRequest request = await PromptWithQuery(query); - string result = await CompletionRequest(request, callback); - - if (addToHistory && result != null) - { - await chatLock.WaitAsync(); - try - { - AddPlayerMessage(query); - AddAIMessage(result); - } - finally - { - chatLock.Release(); - } - if (save != "") _ = Save(save); - } - - completionCallback?.Invoke(); - return result; - } - - /// - /// Pure completion functionality of the LLM. - /// It calls the LLM completion based solely on the provided prompt (no formatting by the chat template). - /// The function allows callbacks when the response is partially or fully received. - /// - /// user query - /// callback function that receives the response as string - /// callback function called when the full response has been received - /// the LLM response - public virtual async Task Complete(string prompt, Callback callback = null, EmptyCallback completionCallback = null) - { - // handle a completion request by the user - // call the callback function while the answer is received - // call the completionCallback function when the answer is fully received - await LoadTemplate(); - - ChatRequest request = GenerateRequest(prompt); - string result = await CompletionRequest(request, callback); - completionCallback?.Invoke(); - return result; - } - - /// - /// Allow to warm-up a model by processing the system prompt. - /// The prompt processing will be cached (if cachePrompt=true) allowing for faster initialisation. - /// The function allows a callback function for when the prompt is processed and the response received. - /// - /// callback function called when the full response has been received - /// the LLM response - public virtual async Task Warmup(EmptyCallback completionCallback = null) - { - await Warmup(null, completionCallback); - } - - /// - /// Allow to warm-up a model by processing the provided prompt without adding it to history. - /// The prompt processing will be cached (if cachePrompt=true) allowing for faster initialisation. - /// The function allows a callback function for when the prompt is processed and the response received. - /// - /// - /// user prompt used during the initialisation (not added to history) - /// callback function called when the full response has been received - /// the LLM response - public virtual async Task Warmup(string query, EmptyCallback completionCallback = null) - { - await LoadTemplate(); - if (!CheckTemplate()) return; - if (!await InitNKeep()) return; - - ChatRequest request; - if (String.IsNullOrEmpty(query)) - { - string prompt = template.ComputePrompt(chat, playerName, AIName, false); - request = GenerateRequest(prompt); - } - else - { - request = await PromptWithQuery(query); - } - - request.n_predict = 0; - await CompletionRequest(request); - completionCallback?.Invoke(); - } - - /// - /// Asks the LLM for the chat template to use. - /// - /// the chat template of the LLM - public virtual async Task AskTemplate() - { - return await PostRequest("{}", "template", TemplateContent); - } - - protected override void CancelRequestsLocal() - { - if (slot >= 0) llm.CancelRequest(slot); - } - - protected virtual async Task Slot(string filepath, string action) - { - SlotRequest slotRequest = new SlotRequest(); - slotRequest.id_slot = slot; - slotRequest.filepath = filepath; - slotRequest.action = action; - string json = JsonUtility.ToJson(slotRequest); - return await PostRequest(json, "slots", SlotContent); - } - - /// - /// Saves the chat history and cache to the provided filename / relative path. - /// - /// filename / relative path to save the chat history - /// - public virtual async Task Save(string filename) - { - string filepath = GetJsonSavePath(filename); - string dirname = Path.GetDirectoryName(filepath); - if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname); - string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) }); - File.WriteAllText(filepath, json); - - string cachepath = GetCacheSavePath(filename); - if (remote || !saveCache) return null; - string result = await Slot(cachepath, "save"); - return result; - } - - /// - /// Load the chat history and cache from the provided filename / relative path. - /// - /// filename / relative path to load the chat history from - /// - public virtual async Task Load(string filename) - { - string filepath = GetJsonSavePath(filename); - if (!File.Exists(filepath)) - { - LLMUnitySetup.LogError($"File {filepath} does not exist."); - return null; - } - string json = File.ReadAllText(filepath); - List chatHistory = JsonUtility.FromJson(json).chat; - ClearChat(); - chat.AddRange(chatHistory); - LLMUnitySetup.Log($"Loaded {filepath}"); - - string cachepath = GetCacheSavePath(filename); - if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null; - string result = await Slot(cachepath, "restore"); - return result; - } - - protected override async Task PostRequestLocal(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback); - - while (!llm.failed && !llm.started) await Task.Yield(); - - string callResult = null; - bool callbackCalled = false; - if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings"); - else - { - Callback callbackString = null; - if (stream && callback != null) - { - if (typeof(Ret) == typeof(string)) - { - callbackString = (strArg) => - { - callback(ConvertContent(strArg, getContent)); - }; - } - else - { - LLMUnitySetup.LogError($"wrong callback type, should be string"); - } - callbackCalled = true; - } - callResult = await llm.Completion(json, callbackString); - } - - Ret result = ConvertContent(callResult, getContent); - if (!callbackCalled) callback?.Invoke(result); - return result; - } - } - - /// \cond HIDE - [Serializable] - public class ChatListWrapper - { - public List chat; - } - /// \endcond -} diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs deleted file mode 100644 index 2919f56c..00000000 --- a/Runtime/LLMChatTemplates.cs +++ /dev/null @@ -1,710 +0,0 @@ -/// @file -/// @brief File implementing the chat templates. -using System.Collections.Generic; -using System.IO; - -namespace LLMUnity -{ - /// @ingroup template - /// - /// Class implementing the skeleton of a chat template - /// - public abstract class ChatTemplate - { - /// the default template used when it can't be determined ("chatml") - public static string DefaultTemplate; - /// a dictionary from chat template name to chat template type. - /// It can be used to get the chat template names supported with: - /// \code - /// ChatTemplate.templates.Keys - /// \endcode - /// - public static Dictionary templates; - /// \cond HIDE - public static ChatTemplate[] templateClasses; - public static Dictionary templatesDescription; - public static Dictionary modelTemplates; - public static Dictionary chatTemplates; - /// \endcond - - static ChatTemplate() - { - DefaultTemplate = "chatml"; - - templateClasses = new ChatTemplate[] - { - new ChatMLTemplate(), - new AlpacaTemplate(), - new GemmaTemplate(), - new MistralChatTemplate(), - new MistralInstructTemplate(), - new LLama3ChatTemplate(), - new LLama2ChatTemplate(), - new LLama2Template(), - new Phi4MiniTemplate(), - new Phi4Template(), - new Phi3_5Template(), - new Phi3Template(), - new Phi2Template(), - new DeepSeekR1Template(), - new DeepSeekV3Template(), - new DeepSeekV2Template(), - new VicunaTemplate(), - new ZephyrTemplate(), - new Qwen3Template(), - new BitNetTemplate(), - }; - - templates = new Dictionary(); - templatesDescription = new Dictionary(); - modelTemplates = new Dictionary(); - chatTemplates = new Dictionary(); - foreach (ChatTemplate template in templateClasses) - { - if (templates.ContainsKey(template.GetName())) LLMUnitySetup.LogError($"{template.GetName()} already in templates"); - templates[template.GetName()] = template; - if (templatesDescription.ContainsKey(template.GetDescription())) LLMUnitySetup.LogError($"{template.GetDescription()} already in templatesDescription"); - templatesDescription[template.GetDescription()] = template.GetName(); - foreach (string match in template.GetNameMatches()) - { - if (modelTemplates.ContainsKey(match)) LLMUnitySetup.LogError($"Name for {template.GetName()} already in modelTemplates"); - modelTemplates[match] = template.GetName(); - } - foreach (string match in template.GetChatTemplateMatches()) - { - if (chatTemplates.ContainsKey(match)) LLMUnitySetup.LogError($"Chat template for {template.GetName()} already in chatTemplates"); - chatTemplates[match] = template.GetName(); - } - } - } - - /// - /// Determines the chat template name from a search name. - /// It searches if any of the chat template names is a substring of the provided name. - /// - /// search name - /// chat template name - public static string FromName(string name) - { - if (name == null) return null; - string nameLower = name.ToLower(); - int maxMatch = 0; - string match = null; - foreach (var pair in modelTemplates) - { - if (nameLower.Contains(pair.Key) && pair.Key.Length > maxMatch) - { - maxMatch = pair.Key.Length; - match = pair.Value; - } - } - return match; - } - - /// - /// Determines the chat template name from a Jinja template. - /// - /// Jinja template - /// chat template name - public static string FromTemplate(string template) - { - if (template == null) return null; - string templateTrim = template.Trim(); - if (chatTemplates.TryGetValue(templateTrim, out string value)) - return value; - return null; - } - - /// - /// Determines the chat template name from a GGUF file. - /// It reads the GGUF file and then determines the chat template name based on: - /// - the jinja template defined in the file (if it exists and matched) - /// - the model name defined in the file (if it exists and matched) - /// - the filename defined in the file (if matched) - /// - otherwises uses the DefaultTemplate - /// - /// GGUF file path - /// template name - public static string FromGGUF(string path) - { - return FromGGUF(new GGUFReader(path), path); - } - - public static string FromGGUF(GGUFReader reader, string path) - { - string name; - name = FromTemplate(reader.GetStringField("tokenizer.chat_template")); - if (name != null) return name; - - name = FromName(reader.GetStringField("general.name")); - if (name != null) return name; - - name = FromName(Path.GetFileNameWithoutExtension(path)); - if (name != null) return name; - - LLMUnitySetup.Log("No chat template could be matched, fallback to ChatML"); - return DefaultTemplate; - } - - /// - /// Creates the chat template based on the provided chat template name - /// - /// chat template name - /// chat template - public static ChatTemplate GetTemplate(string template) - { - return templates[template]; - } - - /// Returns the chat template name - public virtual string GetName() { return ""; } - /// Returns the chat template description - public virtual string GetDescription() { return ""; } - /// Returns an array of names that can be used to match the chat template - public virtual string[] GetNameMatches() { return new string[] {}; } - /// Returns an array of jinja templates that can be used to match the chat template - public virtual string[] GetChatTemplateMatches() { return new string[] {}; } - /// Returns an array of the stopwords used by the template - public virtual string[] GetStop(string playerName, string AIName) { return new string[] {}; } - - protected virtual string PromptPrefix() { return ""; } - protected virtual string SystemPrefix() { return ""; } - protected virtual string SystemSuffix() { return ""; } - protected virtual string PlayerPrefix(string playerName) { return ""; } - protected virtual string AIPrefix(string AIName) { return ""; } - protected virtual string PrefixMessageSeparator() { return ""; } - protected virtual string RequestPrefix() { return ""; } - protected virtual string RequestSuffix() { return ""; } - protected virtual string PairSuffix() { return ""; } - - protected virtual bool SystemPromptSupported() { return true; } - protected virtual bool HasThinkingMode() { return false; } - - /// Constructs the prompt using the template based on a list of ChatMessages - /// list of ChatMessages e.g. the LLMCharacter chat - /// the AI name - /// whether to end the prompt with the AI prefix - /// prompt - public virtual string ComputePrompt(List chatMessages, string playerName, string AIName, bool endWithPrefix = true) - { - List messages = chatMessages; - if (!SystemPromptSupported()) - { - if (chatMessages[0].role == "system") - { - string firstUserMessage = chatMessages[0].content; - int newStart = 1; - if (chatMessages.Count > 1) - { - if (firstUserMessage != "") firstUserMessage += "\n\n"; - firstUserMessage += chatMessages[1].content; - newStart = 2; - } - messages = new List(){new ChatMessage { role = playerName, content = firstUserMessage }}; - messages.AddRange(chatMessages.GetRange(newStart, chatMessages.Count - newStart)); - } - } - - string chatPrompt = PromptPrefix(); - int start = 0; - if (messages[0].role == "system") - { - chatPrompt += RequestPrefix() + SystemPrefix() + messages[0].content + SystemSuffix(); - start = 1; - } - for (int i = start; i < messages.Count; i += 2) - { - if (i > start || start == 0) chatPrompt += RequestPrefix(); - chatPrompt += PlayerPrefix(messages[i].role) + PrefixMessageSeparator() + messages[i].content + RequestSuffix(); - if (i < messages.Count - 1) - { - chatPrompt += AIPrefix(messages[i + 1].role) + PrefixMessageSeparator() + messages[i + 1].content + PairSuffix(); - } - } - if (endWithPrefix) - { - chatPrompt += AIPrefix(AIName); - if (HasThinkingMode()) chatPrompt += "\n\n\n\n"; - } - return chatPrompt; - } - - protected string[] AddStopNewlines(string[] stop) - { - List stopWithNewLines = new List(); - foreach (string stopword in stop) - { - stopWithNewLines.Add(stopword); - stopWithNewLines.Add("\n" + stopword); - } - return stopWithNewLines.ToArray(); - } - } - - /// @ingroup template - /// - /// Class implementing the ChatML template - /// - public class ChatMLTemplate : ChatTemplate - { - public override string GetName() { return "chatml"; } - public override string GetDescription() { return "chatml (generic template)"; } - public override string[] GetNameMatches() { return new string[] {"chatml", "hermes", "qwen"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}; } - - protected override string SystemPrefix() { return "<|im_start|>system\n"; } - protected override string SystemSuffix() { return "<|im_end|>\n"; } - protected override string PlayerPrefix(string playerName) { return $"<|im_start|>{playerName}\n"; } - protected override string AIPrefix(string AIName) { return $"<|im_start|>{AIName}\n"; } - protected override string RequestSuffix() { return "<|im_end|>\n"; } - protected override string PairSuffix() { return "<|im_end|>\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|im_start|>", "<|im_end|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the LLama2 template - /// - public class LLama2Template : ChatTemplate - { - public override string GetName() { return "llama"; } - public override string GetDescription() { return "llama 2"; } - - protected override string SystemPrefix() { return "<>\n"; } - protected override string SystemSuffix() { return "\n<> "; } - protected override string RequestPrefix() { return "[INST] "; } - protected override string RequestSuffix() { return " [/INST]"; } - protected override string PairSuffix() { return " "; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "[INST]", "[/INST]" }); - } - } - - /// @ingroup template - /// - /// Class implementing a modified version of the LLama2 template for chat - /// - public class LLama2ChatTemplate : LLama2Template - { - public override string GetName() { return "llama chat"; } - public override string GetDescription() { return "llama 2 (chat)"; } - public override string[] GetNameMatches() { return new string[] {"llama-2", "llama v2"}; } - - protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } - protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } - protected override string PrefixMessageSeparator() { return " "; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "[INST]", "[/INST]", "###" }); - } - } - - /// @ingroup template - /// - /// Class implementing the LLama3 template for chat - /// - public class LLama3ChatTemplate : ChatTemplate - { - public override string GetName() { return "llama3 chat"; } - public override string GetDescription() { return "llama 3 (chat)"; } - public override string[] GetNameMatches() { return new string[] {"llama-3", "llama v3"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"};} - - protected override string SystemPrefix() { return "<|start_header_id|>system<|end_header_id|>\n\n"; } - protected override string SystemSuffix() { return "<|eot_id|>"; } - - protected override string RequestSuffix() { return "<|eot_id|>"; } - protected override string PairSuffix() { return "<|eot_id|>"; } - - protected override string PlayerPrefix(string playerName) { return $"<|start_header_id|>{playerName}<|end_header_id|>\n\n"; } - protected override string AIPrefix(string AIName) { return $"<|start_header_id|>{AIName}<|end_header_id|>\n\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|eot_id|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Mistral Instruct template - /// - public class MistralInstructTemplate : ChatTemplate - { - public override string GetName() { return "mistral instruct"; } - public override string GetDescription() { return "mistral instruct"; } - - protected override string SystemPrefix() { return ""; } - protected override string SystemSuffix() { return "\n\n"; } - protected override string RequestPrefix() { return "[INST] "; } - protected override string RequestSuffix() { return " [/INST]"; } - protected override string PairSuffix() { return ""; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "", "[INST]", "[/INST]" }); - } - } - - /// @ingroup template - /// - /// Class implementing a modified version of the Mistral Instruct template for chat - /// - public class MistralChatTemplate : MistralInstructTemplate - { - public override string GetName() { return "mistral chat"; } - public override string GetDescription() { return "mistral (chat)"; } - public override string[] GetNameMatches() { return new string[] {"mistral"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"}; } - - protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } - protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } - protected override string PrefixMessageSeparator() { return " "; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "", "[INST]", "[/INST]", "###" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Gemma template - /// - public class GemmaTemplate : ChatTemplate - { - public override string GetName() { return "gemma"; } - public override string GetDescription() { return "gemma"; } - public override string[] GetNameMatches() { return new string[] {"gemma"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}"}; } - - protected override string RequestSuffix() { return "\n"; } - protected override string PairSuffix() { return "\n"; } - - protected override string PlayerPrefix(string playerName) { return "user\n"; } - protected override string AIPrefix(string AIName) { return "model\n"; } - - protected override bool SystemPromptSupported() { return false; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "", "" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Alpaca template - /// - public class AlpacaTemplate : ChatTemplate - { - public override string GetName() { return "alpaca"; } - public override string GetDescription() { return "alpaca (best alternative)"; } - public override string[] GetNameMatches() { return new string[] {"alpaca"}; } - - protected override string SystemSuffix() { return "\n\n"; } - protected override string RequestSuffix() { return "\n"; } - protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } - protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } - protected override string PrefixMessageSeparator() { return " "; } - protected override string PairSuffix() { return "\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "###" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Vicuna template - /// - public class VicunaTemplate : ChatTemplate - { - public override string GetName() { return "vicuna"; } - public override string GetDescription() { return "vicuna"; } - public override string[] GetNameMatches() { return new string[] {"vicuna"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'system' %}{{message['content'] + ' '}}{% elif message['role'] == 'user' %}{{ 'USER: ' + message['content'] + ' '}}{% elif message['role'] == 'assistant' %}{{ 'ASSISTANT: ' + message['content'] + ' '}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT: '}}{% endif %}"}; } - - protected override string SystemSuffix() { return "\n"; } - protected override string PlayerPrefix(string playerName) { return "\n" + playerName + ":"; } - protected override string AIPrefix(string AIName) { return "\n" + AIName + ":"; } - protected override string PrefixMessageSeparator() { return " "; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { playerName + ":", AIName + ":" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Phi-2 template - /// - public class Phi2Template : ChatTemplate - { - public override string GetName() { return "phi"; } - public override string GetDescription() { return "phi-2"; } - public override string[] GetNameMatches() { return new string[] {"phi-2"}; } - - protected override string SystemSuffix() { return "\n\n"; } - protected override string RequestSuffix() { return "\n"; } - protected override string PlayerPrefix(string playerName) { return playerName + ":"; } - protected override string AIPrefix(string AIName) { return AIName + ":"; } - protected override string PrefixMessageSeparator() { return " "; } - protected override string PairSuffix() { return "\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { playerName + ":", AIName + ":" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Phi-3 template - /// - public class Phi3Template : ChatTemplate - { - public override string GetName() { return "phi-3"; } - public override string GetDescription() { return "phi-3"; } - public override string[] GetNameMatches() { return new string[] {"phi-3"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"}; } - - protected override string PlayerPrefix(string playerName) { return $"<|user|>\n"; } - protected override string AIPrefix(string AIName) { return $"<|assistant|>\n"; } - protected override string RequestSuffix() { return "<|end|>\n"; } - protected override string PairSuffix() { return "<|end|>\n"; } - - protected override bool SystemPromptSupported() { return false; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end|>", "<|user|>", "<|assistant|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Phi-4 mini template - /// - public class Phi3_5Template : ChatTemplate - { - public override string GetName() { return "phi-3.5"; } - public override string GetDescription() { return "phi-3.5"; } - public override string[] GetNameMatches() { return new string[] {"phi-3.5"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"};} - - protected override string PlayerPrefix(string playerName) { return $"<|user|>\n"; } - protected override string AIPrefix(string AIName) { return $"<|assistant|>\n"; } - protected override string RequestSuffix() { return "<|end|>\n"; } - protected override string PairSuffix() { return "<|end|>\n"; } - protected override string SystemPrefix() { return "<|system|>\n"; } - protected override string SystemSuffix() { return "<|end|>\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end|>", "<|user|>", "<|assistant|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Phi-4 mini template - /// - public class Phi4MiniTemplate : ChatTemplate - { - public override string GetName() { return "phi-4-mini"; } - public override string GetDescription() { return "phi-4-mini"; } - public override string[] GetNameMatches() { return new string[] {"phi-4-mini"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"};} - - protected override string PlayerPrefix(string playerName) { return $"<|user|>"; } - protected override string AIPrefix(string AIName) { return $"<|assistant|>"; } - protected override string RequestSuffix() { return "<|end|>"; } - protected override string PairSuffix() { return "<|end|>"; } - protected override string SystemPrefix() { return "<|system|>"; } - protected override string SystemSuffix() { return "<|end|>"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end|>", "<|user|>", "<|assistant|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Phi-4 template - /// - public class Phi4Template : ChatTemplate - { - public override string GetName() { return "phi-4"; } - public override string GetDescription() { return "phi-4"; } - public override string[] GetNameMatches() { return new string[] {"phi-4"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}"};} - - protected override string PlayerPrefix(string playerName) { return $"<|im_start|>user<|im_sep|>"; } - protected override string AIPrefix(string AIName) { return $"<|im_start|>assistant<|im_sep|>"; } - protected override string RequestSuffix() { return "<|im_end|>"; } - protected override string PairSuffix() { return "<|im_end|>"; } - protected override string SystemPrefix() { return "<|im_start|>system<|im_sep|>"; } - protected override string SystemSuffix() { return "<|im_end|>"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|im_end|>", "<|im_start|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the Zephyr template - /// - public class ZephyrTemplate : ChatTemplate - { - public override string GetName() { return "zephyr"; } - public override string GetDescription() { return "zephyr"; } - public override string[] GetNameMatches() { return new string[] {"zephyr"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"}; } - - protected override string SystemPrefix() { return "<|system|>\n"; } - protected override string SystemSuffix() { return "\n"; } - protected override string PlayerPrefix(string playerName) { return $"<|user|>\n"; } - protected override string AIPrefix(string AIName) { return $"<|assistant|>\n"; } - protected override string RequestSuffix() { return "\n"; } - protected override string PairSuffix() { return "\n"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { $"<|user|>", $"<|assistant|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the DeepSeek V2 template - /// - public class DeepSeekV2Template : ChatTemplate - { - public override string GetName() { return "deepseek-v2"; } - public override string GetDescription() { return "deepseek-v2"; } - public override string[] GetNameMatches() { return new string[] {"deepseek-v2", "deepseek-llm"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"}; } - - protected override string PrefixMessageSeparator() { return " "; } - protected override string PromptPrefix() { return "<|begin▁of▁sentence|>"; } - protected override string PlayerPrefix(string playerName) { return "User:"; } - protected override string AIPrefix(string AIName) { return "Assistant:"; } - protected override string PairSuffix() { return "<|end▁of▁sentence|>"; } - protected override string RequestSuffix() { return "\n\n"; } - protected override string SystemSuffix() { return "\n\n"; } - - // protected override bool SystemPromptSupported() { return false; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end▁of▁sentence|>", "User:", "Assistant:" }); - } - } - - /// @ingroup template - /// - /// Class implementing the DeepSeek V3 template - /// - public class DeepSeekV3Template : DeepSeekV2Template - { - public override string GetName() { return "deepseek-v3"; } - public override string GetDescription() { return "deepseek-v3"; } - public override string[] GetNameMatches() { return new string[] {"deepseek-v2.5", "deepseek-v3"}; } - public override string[] GetChatTemplateMatches() - { - return new string[] - { - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}", - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}" - }; - } - - protected override string PrefixMessageSeparator() { return ""; } - protected override string PlayerPrefix(string playerName) { return "<|User|>"; } - protected override string AIPrefix(string AIName) { return "<|Assistant|>"; } - protected override string RequestSuffix() { return ""; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end▁of▁sentence|>", "<|User|>", "<|Assistant|>" }); - } - } - - /// @ingroup template - /// - /// Class implementing the DeepSeek R1 template - /// - public class DeepSeekR1Template : DeepSeekV3Template - { - public override string GetName() { return "deepseek-r1"; } - public override string GetDescription() { return "deepseek-r1"; } - public override string[] GetNameMatches() { return new string[] {"deepseek-r1"}; } - public override string[] GetChatTemplateMatches() - { - return new string[] - { - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\\n'}}{% endif %}", - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\\n'}}{% endif %}" - }; - } - - protected override bool HasThinkingMode() { return true; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|end▁of▁sentence|>", "<|User|>", "<|Assistant|>", "" }); - } - } - - - /// @ingroup template - /// - /// Class implementing the Qwen3 template - /// - public class Qwen3Template : ChatMLTemplate - { - public override string GetName() { return "qwen3"; } - public override string GetDescription() { return "qwen3"; } - public override string[] GetNameMatches() { return new string[] { "qwen3" }; } - public override string[] GetChatTemplateMatches() { return new string[] { "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for forward_message in messages %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- set message = messages[index] %}\n {%- set tool_start = '' %}\n {%- set tool_start_length = tool_start|length %}\n {%- set start_of_message = message.content[:tool_start_length] %}\n {%- set tool_end = '' %}\n {%- set tool_end_length = tool_end|length %}\n {%- set start_pos = (message.content|length) - tool_end_length %}\n {%- if start_pos < 0 %}\n {%- set start_pos = 0 %}\n {%- endif %}\n {%- set end_of_message = message.content[start_pos:] %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(start_of_message == tool_start and end_of_message == tool_end) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = (message.content.split('')|last).lstrip('\\n') %}\n {%- set reasoning_content = (message.content.split('')|first).rstrip('\\n') %}\n {%- set reasoning_content = (reasoning_content.split('')|last).lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" }; } - - protected override bool HasThinkingMode() { return true; } - } - - /// @ingroup template - /// - /// Class implementing the BitNet template - /// - public class BitNetTemplate : ChatTemplate - { - public override string GetName() { return "bitnet"; } - public override string GetDescription() { return "bitnet"; } - public override string[] GetNameMatches() { return new string[] {"bitnet"}; } - public override string[] GetChatTemplateMatches() { return new string[] {"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = message['role'] | capitalize + ': '+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' }}{% endif %}"};} - - protected override string PlayerPrefix(string playerName) { return "User: "; } - protected override string AIPrefix(string AIName) { return "Assistant: "; } - protected override string RequestSuffix() { return "<|eot_id|>"; } - protected override string PairSuffix() { return "<|eot_id|>"; } - protected override string SystemPrefix() { return "System: "; } - protected override string SystemSuffix() { return "<|eot_id|>"; } - - public override string[] GetStop(string playerName, string AIName) - { - return AddStopNewlines(new string[] { "<|eot_id|>", "User", "Assistant" }); - } - } -} diff --git a/Runtime/LLMClient.cs b/Runtime/LLMClient.cs new file mode 100644 index 00000000..0b4c0113 --- /dev/null +++ b/Runtime/LLMClient.cs @@ -0,0 +1,578 @@ +/// @file +/// @brief File implementing the base LLM client functionality for Unity. +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using UndreamAI.LlamaLib; +using UnityEngine; +using Newtonsoft.Json.Linq; + +namespace LLMUnity +{ + /// @ingroup llm + /// + /// Unity MonoBehaviour base class for LLM client functionality. + /// Handles both local and remote LLM connections, completion parameters, + /// and provides tokenization, completion, and embedding capabilities. + /// + public class LLMClient : MonoBehaviour + { + #region Inspector Fields + /// Show/hide advanced options in the inspector + [Tooltip("Show/hide advanced options in the inspector")] + [HideInInspector] public bool advancedOptions = false; + + /// Use remote LLM server instead of local instance + [Tooltip("Use remote LLM server instead of local instance")] + [LocalRemote, SerializeField] protected bool _remote; + + /// Local LLM GameObject to connect to + [Tooltip("Local LLM GameObject to connect to")] + [Local, SerializeField] protected LLM _llm; + + /// API key for remote server authentication + [Tooltip("API key for remote server authentication")] + [Remote, SerializeField] protected string _APIKey; + + /// Hostname or IP address of remote LLM server + [Tooltip("Hostname or IP address of remote LLM server")] + [Remote, SerializeField] protected string _host = "localhost"; + + /// Port number of remote LLM server + [Tooltip("Port number of remote LLM server")] + [Remote, SerializeField] protected int _port = 13333; + + /// Grammar constraints for output formatting (GBNF or JSON schema format) + [Tooltip("Grammar constraints for output formatting (GBNF or JSON schema format)")] + [ModelAdvanced, SerializeField] protected string _grammar = ""; + + // Completion Parameters + /// Maximum tokens to generate (-1 = unlimited) + [Tooltip("Maximum tokens to generate (-1 = unlimited)")] + [Model] public int numPredict = -1; + + /// Cache processed prompts to speed up subsequent requests + [Tooltip("Cache processed prompts to speed up subsequent requests")] + [ModelAdvanced] public bool cachePrompt = true; + + /// Random seed for reproducible generation (0 = random) + [Tooltip("Random seed for reproducible generation (0 = random)")] + [ModelAdvanced] public int seed = 0; + + /// Sampling temperature (0.0 = deterministic, higher = more creative) + [Tooltip("Sampling temperature (0.0 = deterministic, higher = more creative)")] + [ModelAdvanced, Range(0f, 2f)] public float temperature = 0.2f; + + /// Top-k sampling: limit to k most likely tokens (0 = disabled) + [Tooltip("Top-k sampling: limit to k most likely tokens (0 = disabled)")] + [ModelAdvanced, Range(0, 100)] public int topK = 40; + + /// Top-p (nucleus) sampling: cumulative probability threshold (1.0 = disabled) + [Tooltip("Top-p (nucleus) sampling: cumulative probability threshold (1.0 = disabled)")] + [ModelAdvanced, Range(0f, 1f)] public float topP = 0.9f; + + /// Minimum probability threshold for token selection + [Tooltip("Minimum probability threshold for token selection")] + [ModelAdvanced, Range(0f, 1f)] public float minP = 0.05f; + + /// Penalty for repeated tokens (1.0 = no penalty) + [Tooltip("Penalty for repeated tokens (1.0 = no penalty)")] + [ModelAdvanced, Range(0f, 2f)] public float repeatPenalty = 1.1f; + + /// Presence penalty: reduce likelihood of any repeated token (0.0 = disabled) + [Tooltip("Presence penalty: reduce likelihood of any repeated token (0.0 = disabled)")] + [ModelAdvanced, Range(0f, 1f)] public float presencePenalty = 0f; + + /// Frequency penalty: reduce likelihood based on token frequency (0.0 = disabled) + [Tooltip("Frequency penalty: reduce likelihood based on token frequency (0.0 = disabled)")] + [ModelAdvanced, Range(0f, 1f)] public float frequencyPenalty = 0f; + + /// Locally typical sampling strength (1.0 = disabled) + [Tooltip("Locally typical sampling strength (1.0 = disabled)")] + [ModelAdvanced, Range(0f, 1f)] public float typicalP = 1f; + + /// Number of recent tokens to consider for repetition penalty (0 = disabled, -1 = context size) + [Tooltip("Number of recent tokens to consider for repetition penalty (0 = disabled, -1 = context size)")] + [ModelAdvanced, Range(0, 2048)] public int repeatLastN = 64; + + /// Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + [Tooltip("Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")] + [ModelAdvanced, Range(0, 2)] public int mirostat = 0; + + /// Mirostat target entropy (tau) - balance between coherence and diversity + [Tooltip("Mirostat target entropy (tau) - balance between coherence and diversity")] + [ModelAdvanced, Range(0f, 10f)] public float mirostatTau = 5f; + + /// Mirostat learning rate (eta) - adaptation speed + [Tooltip("Mirostat learning rate (eta) - adaptation speed")] + [ModelAdvanced, Range(0f, 1f)] public float mirostatEta = 0.1f; + + /// Include top N token probabilities in response (0 = disabled) + [Tooltip("Include top N token probabilities in response (0 = disabled)")] + [ModelAdvanced, Range(0, 10)] public int nProbs = 0; + + /// Ignore end-of-stream token and continue generating + [Tooltip("Ignore end-of-stream token and continue generating")] + [ModelAdvanced] public bool ignoreEos = false; + #endregion + + #region Public Properties + /// Whether this client uses a remote server connection + public bool remote + { + get => _remote; + set + { + if (_remote != value) + { + _remote = value; + if (started) _ = SetupLLMClient(); + } + } + } + + /// The local LLM instance (null if using remote) + public LLM llm + { + get => _llm; + set => _ = SetLLM(value); + } + + /// API key for remote server authentication + public string APIKey + { + get => _APIKey; + set + { + if (_APIKey != value) + { + _APIKey = value; + if (started) _ = SetupLLMClient(); + } + } + } + + /// Remote server hostname or IP address + public string host + { + get => _host; + set + { + if (_host != value) + { + _host = value; + if (started) _ = SetupLLMClient(); + } + } + } + + /// Remote server port number + public int port + { + get => _port; + set + { + if (_port != value) + { + _port = value; + if (started) _ = SetupLLMClient(); + } + } + } + + /// Current grammar constraints for output formatting + public string grammar + { + get => _grammar; + set => SetGrammar(value); + } + #endregion + + #region Private Fields + protected UndreamAI.LlamaLib.LLMClient llmClient; + private bool started = false; + private string completionParametersCache = ""; + #endregion + + #region Unity Lifecycle + /// + /// Unity Awake method that validates configuration and assigns local LLM if needed. + /// + public virtual void Awake() + { + if (!enabled) return; + + if (!remote) + { + AssignLLM(); + if (llm == null) + { + string error = $"No LLM assigned or detected for {GetType().Name} '{name}'!"; + LLMUnitySetup.LogError(error); + throw new InvalidOperationException(error); + } + } + } + + /// + /// Unity Start method that initializes the LLM client connection. + /// + public virtual async void Start() + { + if (!enabled) return; + await SetupLLMClient(); + started = true; + } + + protected virtual void OnValidate() + { + AssignLLM(); + } + + protected virtual void Reset() + { + AssignLLM(); + } + + #endregion + + #region Initialization + private void CheckLLMClient() + { + if (llmClient == null) + { + string error = "LLMClient not initialized"; + LLMUnitySetup.LogError(error); + throw new System.InvalidOperationException(error); + } + } + + /// + /// Sets up the underlying LLM client connection (local or remote). + /// + protected virtual async Task SetupLLMClient() + { + string exceptionMessage = ""; + try + { + if (!remote) + { + if (llm != null) await llm.WaitUntilReady(); + if (llm?.llmService == null) + { + throw new InvalidOperationException("Local LLM service is not available"); + } + llmClient = new UndreamAI.LlamaLib.LLMClient(llm.llmService); + } + else + { + llmClient = new UndreamAI.LlamaLib.LLMClient(host, port, APIKey); + } + } + catch (Exception ex) + { + exceptionMessage = ex.Message; + } + if (llmClient == null || exceptionMessage != "") + { + string error = "llmClient not initialized"; + if (exceptionMessage != "") error += ", error: " + exceptionMessage; + LLMUnitySetup.LogError(error); + throw new InvalidOperationException(error); + } + + SetGrammar(grammar); + completionParametersCache = ""; + } + + /// + /// Gets the underlying LLMLocal instance for operations requiring local access. + /// + protected virtual LLMLocal GetCaller() + { + return llmClient; + } + + /// + /// Sets the local LLM instance for this client. + /// + /// LLM instance to connect to + protected virtual async Task SetLLM(LLM llmInstance) + { + if (llmInstance == _llm) return; + + if (remote) + { + LLMUnitySetup.LogError("Cannot set LLM when client is in remote mode"); + return; + } + + _llm = llmInstance; + if (started) await SetupLLMClient(); + } + + #endregion + + #region LLM Assignment + /// + /// Determines if an LLM instance can be auto-assigned to this client. + /// Override in derived classes to implement specific assignment logic. + /// + /// LLM instance to evaluate + /// True if the LLM can be auto-assigned + public virtual bool IsAutoAssignableLLM(LLM llmInstance) + { + return true; + } + + /// + /// Automatically assigns a suitable LLM instance if none is set. + /// + protected virtual void AssignLLM() + { + if (remote || llm != null) return; + + var validLLMs = new List(); + +#if UNITY_6000_0_OR_NEWER + foreach (LLM foundLlm in FindObjectsByType(FindObjectsSortMode.None)) +#else + foreach (LLM foundLlm in FindObjectsOfType()) +#endif + { + if (IsAutoAssignableLLM(foundLlm)) + { + validLLMs.Add(foundLlm); + } + } + + if (validLLMs.Count == 0) return; + + llm = SortLLMsByBestMatch(validLLMs.ToArray())[0]; + + string message = $"Auto-assigned LLM '{llm.name}' to {GetType().Name} '{name}'"; + if (llm.gameObject.scene != gameObject.scene) + { + message += $" (from scene '{llm.gameObject.scene.name}')"; + } + LLMUnitySetup.Log(message); + } + + /// + /// Sorts LLM instances by compatibility, preferring same-scene objects and hierarchy order. + /// + protected virtual LLM[] SortLLMsByBestMatch(LLM[] llmArray) + { + LLM[] array = (LLM[])llmArray.Clone(); + for (int i = 0; i < array.Length - 1; i++) + { + bool swapped = false; + for (int j = 0; j < array.Length - i - 1; j++) + { + bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene; + bool swap = ( + (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) || + (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex()) + ); + if (swap) + { + LLM temp = array[j]; + array[j] = array[j + 1]; + array[j + 1] = temp; + swapped = true; + } + } + if (!swapped) break; + } + return array; + } + + #endregion + + #region Grammar Management + /// + /// Sets grammar constraints for structured output generation. + /// + /// Grammar in GBNF or JSON schema format + public virtual void SetGrammar(string grammarString) + { + _grammar = grammarString ?? ""; + GetCaller()?.SetGrammar(_grammar); + } + + /// + /// Loads grammar constraints from a file. + /// + /// Path to grammar file + public virtual void LoadGrammar(string path) + { + if (string.IsNullOrEmpty(path)) return; + + if (!File.Exists(path)) + { + LLMUnitySetup.LogError($"Grammar file not found: {path}"); + return; + } + + try + { + string grammarContent = File.ReadAllText(path); + SetGrammar(grammarContent); + LLMUnitySetup.Log($"Loaded grammar from: {path}"); + } + catch (Exception ex) + { + LLMUnitySetup.LogError($"Failed to load grammar file '{path}': {ex.Message}"); + } + } + + #endregion + + #region Completion Parameters + /// + /// Applies current completion parameters to the LLM client. + /// Only updates if parameters have changed since last call. + /// + protected virtual void SetCompletionParameters() + { + if (llm != null && llm.embeddingsOnly) + { + string error = "LLM can't be used for completion, it is an embeddings only model!"; + LLMUnitySetup.LogError(error); + throw new Exception(error); + } + + var parameters = new JObject + { + ["temperature"] = temperature, + ["top_k"] = topK, + ["top_p"] = topP, + ["min_p"] = minP, + ["n_predict"] = numPredict, + ["typical_p"] = typicalP, + ["repeat_penalty"] = repeatPenalty, + ["repeat_last_n"] = repeatLastN, + ["presence_penalty"] = presencePenalty, + ["frequency_penalty"] = frequencyPenalty, + ["mirostat"] = mirostat, + ["mirostat_tau"] = mirostatTau, + ["mirostat_eta"] = mirostatEta, + ["seed"] = seed, + ["ignore_eos"] = ignoreEos, + ["n_probs"] = nProbs, + ["cache_prompt"] = cachePrompt + }; + + string parametersJson = parameters.ToString(); + if (parametersJson != completionParametersCache) + { + GetCaller()?.SetCompletionParameters(parameters); + completionParametersCache = parametersJson; + } + } + + #endregion + + #region Core LLM Operations + /// + /// Converts text into a list of token IDs. + /// + /// Text to tokenize + /// Optional callback to receive the result + /// List of token IDs + public virtual List Tokenize(string query, Callback> callback = null) + { + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentNullException(nameof(query)); + } + CheckLLMClient(); + + List tokens = llmClient.Tokenize(query); + callback?.Invoke(tokens); + return tokens; + } + + /// + /// Converts token IDs back to text. + /// + /// Token IDs to decode + /// Optional callback to receive the result + /// Decoded text + public virtual string Detokenize(List tokens, Callback callback = null) + { + if (tokens == null) + { + throw new ArgumentNullException(nameof(tokens)); + } + CheckLLMClient(); + + string text = llmClient.Detokenize(tokens); + callback?.Invoke(text); + return text; + } + + /// + /// Generates embedding vectors for the input text. + /// + /// Text to embed + /// Optional callback to receive the result + /// Embedding vector + public virtual List Embeddings(string query, Callback> callback = null) + { + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentNullException(nameof(query)); + } + CheckLLMClient(); + + List embeddings = llmClient.Embeddings(query); + callback?.Invoke(embeddings); + return embeddings; + } + + /// + /// Generates text completion for the given prompt. + /// + /// Input prompt text + /// Optional streaming callback for partial responses + /// Slot ID for the request (-1 for auto-assignment) + /// Generated completion text + public virtual string Completion(string prompt, LlamaLib.CharArrayCallback callback = null, int id_slot = -1) + { + CheckLLMClient(); + SetCompletionParameters(); + return llmClient.Completion(prompt, callback, id_slot); + } + + /// + /// Generates text completion asynchronously. + /// + /// Input prompt text + /// Optional streaming callback for partial responses + /// Optional callback when completion finishes + /// Slot ID for the request (-1 for auto-assignment) + /// Task that returns the generated completion text + public virtual async Task CompletionAsync(string prompt, LlamaLib.CharArrayCallback callback = null, + EmptyCallback completionCallback = null, int id_slot = -1) + { + CheckLLMClient(); + SetCompletionParameters(); + string result = await llmClient.CompletionAsync(prompt, callback, id_slot); + completionCallback?.Invoke(); + return result; + } + + /// + /// Cancels an active request in the specified slot. + /// + /// Slot ID of the request to cancel + public void CancelRequest(int id_slot) + { + llmClient?.Cancel(id_slot); + } + + #endregion + } +} diff --git a/Runtime/LLMCaller.cs.meta b/Runtime/LLMClient.cs.meta similarity index 100% rename from Runtime/LLMCaller.cs.meta rename to Runtime/LLMClient.cs.meta diff --git a/Runtime/LLMEmbedder.cs b/Runtime/LLMEmbedder.cs index 1a7962b5..0bbf54e3 100644 --- a/Runtime/LLMEmbedder.cs +++ b/Runtime/LLMEmbedder.cs @@ -1,5 +1,6 @@ /// @file /// @brief File implementing the LLM embedder. +using System.Threading.Tasks; using UnityEngine; namespace LLMUnity @@ -9,11 +10,11 @@ namespace LLMUnity /// /// Class implementing the LLM embedder. /// - public class LLMEmbedder : LLMCaller + public class LLMEmbedder : LLMClient { - protected override void SetLLM(LLM llmSet) + protected override async Task SetLLM(LLM llmSet) { - base.SetLLM(llmSet); + await base.SetLLM(llmSet); if (llmSet != null && !llmSet.embeddingsOnly) { LLMUnitySetup.LogWarning($"The LLM {llmSet.name} set for LLMEmbeddings {gameObject.name} is not an embeddings-only model, accuracy may be sub-optimal"); diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs deleted file mode 100644 index 748131ff..00000000 --- a/Runtime/LLMInterface.cs +++ /dev/null @@ -1,158 +0,0 @@ -/// @file -/// @brief File implementing the LLM server interfaces. -using System; -using System.Collections.Generic; - -/// \cond HIDE -namespace LLMUnity -{ - [Serializable] - public struct ChatRequest - { - public string prompt; - public int id_slot; - public float temperature; - public int top_k; - public float top_p; - public float min_p; - public int n_predict; - public int n_keep; - public bool stream; - public List stop; - public float tfs_z; - public float typical_p; - public float repeat_penalty; - public int repeat_last_n; - public bool penalize_nl; - public float presence_penalty; - public float frequency_penalty; - public string penalty_prompt; - public int mirostat; - public float mirostat_tau; - public float mirostat_eta; - // EXCLUDE grammars from JsonUtility serialization, serialise them manually - [NonSerialized] public string grammar; - [NonSerialized] public string json_schema; - public int seed; - public bool ignore_eos; - public Dictionary logit_bias; - public int n_probs; - public bool cache_prompt; - public List messages; - } - - [Serializable] - public struct GrammarWrapper - { - public string grammar; - } - - [Serializable] - public struct SystemPromptRequest - { - public string prompt; - public string system_prompt; - public int n_predict; - } - - [Serializable] - public struct ChatResult - { - public int id_slot; - public string content; - public bool stop; - public string generation_settings; - public string model; - public string prompt; - public bool stopped_eos; - public bool stopped_limit; - public bool stopped_word; - public string stopping_word; - public string timings; - public int tokens_cached; - public int tokens_evaluated; - public bool truncated; - public bool cache_prompt; - public bool system_prompt; - } - - [Serializable] - public struct MultiChatResult - { - public List data; - } - - [Serializable] - public struct ChatMessage - { - public string role; - public string content; - } - - [Serializable] - public struct TokenizeRequest - { - public string content; - } - - [Serializable] - public struct TokenizeResult - { - public List tokens; - } - - [Serializable] - public struct EmbeddingsResult - { - public List embedding; - } - - [Serializable] - public struct LoraWeightRequest - { - public int id; - public float scale; - } - - [Serializable] - public struct LoraWeightRequestList - { - public List loraWeights; - } - - [Serializable] - public struct LoraWeightResult - { - public int id; - public string path; - public float scale; - } - - [Serializable] - public struct LoraWeightResultList - { - public List loraWeights; - } - - [Serializable] - public struct TemplateResult - { - public string template; - } - - [Serializable] - public struct SlotRequest - { - public int id_slot; - public string action; - public string filepath; - } - - [Serializable] - public struct SlotResult - { - public int id_slot; - public string filename; - } -} -/// \endcond diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs deleted file mode 100644 index f7417453..00000000 --- a/Runtime/LLMLib.cs +++ /dev/null @@ -1,727 +0,0 @@ -/// @file -/// @brief File implementing the LLM library calls. -/// \cond HIDE -using System; -using System.Collections.Generic; -using System.IO; -using System.Runtime.InteropServices; -using UnityEngine; - -namespace LLMUnity -{ - /// @ingroup utils - /// - /// Class implementing a wrapper for a communication stream between Unity and the llama.cpp library (mainly for completion calls and logging). - /// - public class StreamWrapper - { - LLMLib llmlib; - Callback callback; - IntPtr stringWrapper; - string previousString = ""; - string previousCalledString = ""; - int previousBufferSize = 0; - bool clearOnUpdate; - - public StreamWrapper(LLMLib llmlib, Callback callback, bool clearOnUpdate = false) - { - this.llmlib = llmlib; - this.callback = callback; - this.clearOnUpdate = clearOnUpdate; - stringWrapper = (llmlib?.StringWrapper_Construct()).GetValueOrDefault(); - } - - /// - /// Retrieves the content of the stream - /// - /// whether to clear the stream after retrieving the content - /// stream content - public string GetString(bool clear = false) - { - string result; - int bufferSize = (llmlib?.StringWrapper_GetStringSize(stringWrapper)).GetValueOrDefault(); - if (bufferSize <= 1) - { - result = ""; - } - else if (previousBufferSize != bufferSize) - { - IntPtr buffer = Marshal.AllocHGlobal(bufferSize); - try - { - llmlib?.StringWrapper_GetString(stringWrapper, buffer, bufferSize, clear); - result = Marshal.PtrToStringAnsi(buffer); - } - finally - { - Marshal.FreeHGlobal(buffer); - } - previousString = result; - } - else - { - result = previousString; - } - previousBufferSize = bufferSize; - return result; - } - - /// - /// Unity Update implementation that retrieves the content and calls the callback if it has changed. - /// - public void Update() - { - if (stringWrapper == IntPtr.Zero) return; - string result = GetString(clearOnUpdate); - if (result != "" && previousCalledString != result) - { - callback?.Invoke(result); - previousCalledString = result; - } - } - - /// - /// Gets the stringWrapper object to pass to the library. - /// - /// stringWrapper object - public IntPtr GetStringWrapper() - { - return stringWrapper; - } - - /// - /// Deletes the stringWrapper object. - /// - public void Destroy() - { - if (stringWrapper != IntPtr.Zero) llmlib?.StringWrapper_Delete(stringWrapper); - } - } - - /// @ingroup utils - /// - /// Class implementing a library loader for Unity. - /// Adapted from SkiaForUnity: - /// https://github.com/ammariqais/SkiaForUnity/blob/f43322218c736d1c41f3a3df9355b90db4259a07/SkiaUnity/Assets/SkiaSharp/SkiaSharp-Bindings/SkiaSharp.HarfBuzz.Shared/HarfBuzzSharp.Shared/LibraryLoader.cs - /// - static class LibraryLoader - { - /// - /// Allows to retrieve a function delegate for the library - /// - /// type to cast the function - /// library handle - /// function name - /// function delegate - public static T GetSymbolDelegate(IntPtr library, string name) where T : Delegate - { - var symbol = GetSymbol(library, name); - if (symbol == IntPtr.Zero) - throw new EntryPointNotFoundException($"Unable to load symbol '{name}'."); - - return Marshal.GetDelegateForFunctionPointer(symbol); - } - - /// - /// Loads the provided library in a cross-platform manner - /// - /// library path - /// library handle - public static IntPtr LoadLibrary(string libraryName) - { - if (string.IsNullOrEmpty(libraryName)) - throw new ArgumentNullException(nameof(libraryName)); - - IntPtr handle; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - handle = Win32.LoadLibrary(libraryName); - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - handle = Linux.dlopen(libraryName); - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer || Application.platform == RuntimePlatform.OSXServer) - handle = Mac.dlopen(libraryName); - else if (Application.platform == RuntimePlatform.Android || Application.platform == RuntimePlatform.IPhonePlayer || Application.platform == RuntimePlatform.VisionOS) - handle = Mobile.dlopen(libraryName); - else - throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryName}'."); - - return handle; - } - - /// - /// Retrieve a function delegate for the library in a cross-platform manner - /// - /// library handle - /// function name - /// function handle - public static IntPtr GetSymbol(IntPtr library, string symbolName) - { - if (string.IsNullOrEmpty(symbolName)) - throw new ArgumentNullException(nameof(symbolName)); - - IntPtr handle; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - handle = Win32.GetProcAddress(library, symbolName); - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - handle = Linux.dlsym(library, symbolName); - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer || Application.platform == RuntimePlatform.OSXServer) - handle = Mac.dlsym(library, symbolName); - else if (Application.platform == RuntimePlatform.Android || Application.platform == RuntimePlatform.IPhonePlayer || Application.platform == RuntimePlatform.VisionOS) - handle = Mobile.dlsym(library, symbolName); - else - throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}."); - - return handle; - } - - /// - /// Frees up the library - /// - /// library handle - public static void FreeLibrary(IntPtr library) - { - if (library == IntPtr.Zero) - return; - - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - Win32.FreeLibrary(library); - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - Linux.dlclose(library); - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer || Application.platform == RuntimePlatform.OSXServer) - Mac.dlclose(library); - else if (Application.platform == RuntimePlatform.Android || Application.platform == RuntimePlatform.IPhonePlayer || Application.platform == RuntimePlatform.VisionOS) - Mobile.dlclose(library); - else - throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'."); - } - - private static class Mac - { - private const string SystemLibrary = "/usr/lib/libSystem.dylib"; - - private const int RTLD_LAZY = 1; - private const int RTLD_NOW = 2; - - public static IntPtr dlopen(string path, bool lazy = true) => - dlopen(path, lazy ? RTLD_LAZY : RTLD_NOW); - - [DllImport(SystemLibrary)] - public static extern IntPtr dlopen(string path, int mode); - - [DllImport(SystemLibrary)] - public static extern IntPtr dlsym(IntPtr handle, string symbol); - - [DllImport(SystemLibrary)] - public static extern void dlclose(IntPtr handle); - } - - private static class Linux - { - private const string SystemLibrary = "libdl.so"; - private const string SystemLibrary2 = "libdl.so.2"; // newer Linux distros use this - - private const int RTLD_LAZY = 1; - private const int RTLD_NOW = 2; - - private static bool UseSystemLibrary2 = true; - - public static IntPtr dlopen(string path, bool lazy = true) - { - try - { - return dlopen2(path, lazy ? RTLD_LAZY : RTLD_NOW); - } - catch (DllNotFoundException) - { - UseSystemLibrary2 = false; - return dlopen1(path, lazy ? RTLD_LAZY : RTLD_NOW); - } - } - - public static IntPtr dlsym(IntPtr handle, string symbol) - { - return UseSystemLibrary2 ? dlsym2(handle, symbol) : dlsym1(handle, symbol); - } - - public static void dlclose(IntPtr handle) - { - if (UseSystemLibrary2) - dlclose2(handle); - else - dlclose1(handle); - } - - [DllImport(SystemLibrary, EntryPoint = "dlopen")] - private static extern IntPtr dlopen1(string path, int mode); - - [DllImport(SystemLibrary, EntryPoint = "dlsym")] - private static extern IntPtr dlsym1(IntPtr handle, string symbol); - - [DllImport(SystemLibrary, EntryPoint = "dlclose")] - private static extern void dlclose1(IntPtr handle); - - [DllImport(SystemLibrary2, EntryPoint = "dlopen")] - private static extern IntPtr dlopen2(string path, int mode); - - [DllImport(SystemLibrary2, EntryPoint = "dlsym")] - private static extern IntPtr dlsym2(IntPtr handle, string symbol); - - [DllImport(SystemLibrary2, EntryPoint = "dlclose")] - private static extern void dlclose2(IntPtr handle); - } - - private static class Win32 - { - private const string SystemLibrary = "Kernel32.dll"; - - [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)] - public static extern IntPtr LoadLibrary(string lpFileName); - - [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)] - public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName); - - [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)] - public static extern void FreeLibrary(IntPtr hModule); - } - - private static class Mobile - { - public static IntPtr dlopen(string path) => dlopen(path, 1); - -#if UNITY_ANDROID || UNITY_IOS || UNITY_VISIONOS - [DllImport("__Internal")] - public static extern IntPtr dlopen(string filename, int flags); - - [DllImport("__Internal")] - public static extern IntPtr dlsym(IntPtr handle, string symbol); - - [DllImport("__Internal")] - public static extern int dlclose(IntPtr handle); -#else - public static IntPtr dlopen(string filename, int flags) - { - return default; - } - - public static IntPtr dlsym(IntPtr handle, string symbol) - { - return default; - } - - public static int dlclose(IntPtr handle) - { - return default; - } - -#endif - } - } - - /// @ingroup utils - /// - /// Class implementing the LLM library handling - /// - public class LLMLib - { - public string architecture { get; private set; } - IntPtr libraryHandle = IntPtr.Zero; - static bool has_avx = false; - static bool has_avx2 = false; - static bool has_avx512 = false; - List dependencyHandles = new List(); - -#if (UNITY_ANDROID || UNITY_IOS || UNITY_VISIONOS) && !UNITY_EDITOR - - public LLMLib(string arch) - { - architecture = arch; - } - -#if UNITY_ANDROID - public const string LibraryName = "libundreamai_android"; -#else - public const string LibraryName = "__Internal"; -#endif - - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "Logging")] - public static extern void LoggingStatic(IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "StopLogging")] - public static extern void StopLoggingStatic(); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Construct")] - public static extern IntPtr LLM_ConstructStatic(string command); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Delete")] - public static extern void LLM_DeleteStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_StartServer")] - public static extern void LLM_StartServerStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_StopServer")] - public static extern void LLM_StopServerStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Start")] - public static extern void LLM_StartStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Started")] - public static extern bool LLM_StartedStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Stop")] - public static extern void LLM_StopStatic(IntPtr LLMObject); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_SetTemplate")] - public static extern void LLM_SetTemplateStatic(IntPtr LLMObject, string chatTemplate); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_SetSSL")] - public static extern void LLM_SetSSLStatic(IntPtr LLMObject, string SSLCert, string SSLKey); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Tokenize")] - public static extern void LLM_TokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Detokenize")] - public static extern void LLM_DetokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Embeddings")] - public static extern void LLM_EmbeddingsStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Lora_Weight")] - public static extern void LLM_LoraWeightStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Lora_List")] - public static extern void LLM_LoraListStatic(IntPtr LLMObject, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Completion")] - public static extern void LLM_CompletionStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Slot")] - public static extern void LLM_SlotStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Cancel")] - public static extern void LLM_CancelStatic(IntPtr LLMObject, int idSlot); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Status")] - public static extern int LLM_StatusStatic(IntPtr LLMObject, IntPtr stringWrapper); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "StringWrapper_Construct")] - public static extern IntPtr StringWrapper_ConstructStatic(); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "StringWrapper_Delete")] - public static extern void StringWrapper_DeleteStatic(IntPtr instance); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "StringWrapper_GetStringSize")] - public static extern int StringWrapper_GetStringSizeStatic(IntPtr instance); - [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "StringWrapper_GetString")] - public static extern void StringWrapper_GetStringStatic(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false); - - public void Logging(IntPtr stringWrapper) { LoggingStatic(stringWrapper); } - public void StopLogging() { StopLoggingStatic(); } - public IntPtr LLM_Construct(string command) { return LLM_ConstructStatic(command); } - public void LLM_Delete(IntPtr LLMObject) { LLM_DeleteStatic(LLMObject); } - public void LLM_StartServer(IntPtr LLMObject) { LLM_StartServerStatic(LLMObject); } - public void LLM_StopServer(IntPtr LLMObject) { LLM_StopServerStatic(LLMObject); } - public void LLM_Start(IntPtr LLMObject) { LLM_StartStatic(LLMObject); } - public bool LLM_Started(IntPtr LLMObject) { return LLM_StartedStatic(LLMObject); } - public void LLM_Stop(IntPtr LLMObject) { LLM_StopStatic(LLMObject); } - public void LLM_SetTemplate(IntPtr LLMObject, string chatTemplate) { LLM_SetTemplateStatic(LLMObject, chatTemplate); } - public void LLM_SetSSL(IntPtr LLMObject, string SSLCert, string SSLKey) { LLM_SetSSLStatic(LLMObject, SSLCert, SSLKey); } - public void LLM_Tokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_TokenizeStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_Detokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_DetokenizeStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_Embeddings(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_EmbeddingsStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_LoraWeight(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_LoraWeightStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_LoraList(IntPtr LLMObject, IntPtr stringWrapper) { LLM_LoraListStatic(LLMObject, stringWrapper); } - public void LLM_Completion(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_CompletionStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_Slot(IntPtr LLMObject, string jsonData, IntPtr stringWrapper) { LLM_SlotStatic(LLMObject, jsonData, stringWrapper); } - public void LLM_Cancel(IntPtr LLMObject, int idSlot) { LLM_CancelStatic(LLMObject, idSlot); } - public int LLM_Status(IntPtr LLMObject, IntPtr stringWrapper) { return LLM_StatusStatic(LLMObject, stringWrapper); } - public IntPtr StringWrapper_Construct() { return StringWrapper_ConstructStatic(); } - public void StringWrapper_Delete(IntPtr instance) { StringWrapper_DeleteStatic(instance); } - public int StringWrapper_GetStringSize(IntPtr instance) { return StringWrapper_GetStringSizeStatic(instance); } - public void StringWrapper_GetString(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false) { StringWrapper_GetStringStatic(instance, buffer, bufferSize, clear); } - -#else - - static bool has_avx_set = false; - static readonly object staticLock = new object(); - - static LLMLib() - { - lock (staticLock) - { - if (has_avx_set) return; - string archCheckerPath = GetArchitectureCheckerPath(); - if (archCheckerPath != null) - { - IntPtr archCheckerHandle = LibraryLoader.LoadLibrary(archCheckerPath); - if (archCheckerHandle == IntPtr.Zero) - { - LLMUnitySetup.LogError($"Failed to load library {archCheckerPath}."); - } - else - { - try - { - has_avx = LibraryLoader.GetSymbolDelegate(archCheckerHandle, "has_avx")(); - has_avx2 = LibraryLoader.GetSymbolDelegate(archCheckerHandle, "has_avx2")(); - has_avx512 = LibraryLoader.GetSymbolDelegate(archCheckerHandle, "has_avx512")(); - LibraryLoader.FreeLibrary(archCheckerHandle); - } - catch (Exception e) - { - LLMUnitySetup.LogError($"{e.GetType()}: {e.Message}"); - } - } - } - has_avx_set = true; - } - } - - /// - /// Loads the library and function handles for the defined architecture - /// - /// archtecture - /// - public LLMLib(string arch) - { - architecture = arch; - foreach (string dependency in GetArchitectureDependencies(arch)) - { - LLMUnitySetup.Log($"Loading {dependency}"); - dependencyHandles.Add(LibraryLoader.LoadLibrary(dependency)); - } - - libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch)); - if (libraryHandle == IntPtr.Zero) - { - throw new Exception($"Failed to load library {arch}."); - } - - LLM_Construct = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Construct"); - LLM_Delete = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Delete"); - LLM_StartServer = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_StartServer"); - LLM_StopServer = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_StopServer"); - LLM_Start = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Start"); - LLM_Started = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Started"); - LLM_Stop = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Stop"); - LLM_SetTemplate = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_SetTemplate"); - LLM_SetSSL = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_SetSSL"); - LLM_Tokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Tokenize"); - LLM_Detokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Detokenize"); - LLM_Embeddings = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embeddings"); - LLM_LoraWeight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight"); - LLM_LoraList = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_List"); - LLM_Completion = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Completion"); - LLM_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Slot"); - LLM_Cancel = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Cancel"); - LLM_Status = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Status"); - StringWrapper_Construct = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_Construct"); - StringWrapper_Delete = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_Delete"); - StringWrapper_GetStringSize = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_GetStringSize"); - StringWrapper_GetString = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_GetString"); - Logging = LibraryLoader.GetSymbolDelegate(libraryHandle, "Logging"); - StopLogging = LibraryLoader.GetSymbolDelegate(libraryHandle, "StopLogging"); - } - - /// - /// Gets the path of a library that allows to detect the underlying CPU (Windows / Linux). - /// - /// architecture checker library path - public static string GetArchitectureCheckerPath() - { - string filename; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - { - filename = $"windows-archchecker/archchecker.dll"; - } - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - { - filename = $"linux-archchecker/libarchchecker.so"; - } - else - { - return null; - } - return Path.Combine(LLMUnitySetup.libraryPath, filename); - } - - /// - /// Gets additional dependencies for the specified architecture. - /// - /// architecture - /// paths of dependency dlls - public static List GetArchitectureDependencies(string arch) - { - List dependencies = new List(); - if (arch == "cuda-cu12.2.0-full") - { - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - { - dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cudart64_12.dll")); - dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cublasLt64_12.dll")); - dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cublas64_12.dll")); - } - } - else if (arch == "vulkan") - { - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - { - dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/vulkan-1.dll")); - } - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - { - dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"linux-{arch}/libvulkan.so.1")); - } - } - return dependencies; - } - - /// - /// Gets the path of the llama.cpp library for the specified architecture. - /// - /// architecture - /// llama.cpp library path - public static string GetArchitecturePath(string arch) - { - string filename; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) - { - filename = $"windows-{arch}/undreamai_windows-{arch}.dll"; - } - else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - { - filename = $"linux-{arch}/libundreamai_linux-{arch}.so"; - } - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer || Application.platform == RuntimePlatform.OSXServer) - { - filename = $"macos-{arch}/libundreamai_macos-{arch}.dylib"; - } - else - { - string error = "Unknown OS"; - LLMUnitySetup.LogError(error); - throw new Exception(error); - } - return Path.Combine(LLMUnitySetup.libraryPath, filename); - } - - public delegate bool HasArchDelegate(); - public delegate void LoggingDelegate(IntPtr stringWrapper); - public delegate void StopLoggingDelegate(); - public delegate IntPtr LLM_ConstructDelegate(string command); - public delegate void LLM_DeleteDelegate(IntPtr LLMObject); - public delegate void LLM_StartServerDelegate(IntPtr LLMObject); - public delegate void LLM_StopServerDelegate(IntPtr LLMObject); - public delegate void LLM_StartDelegate(IntPtr LLMObject); - public delegate bool LLM_StartedDelegate(IntPtr LLMObject); - public delegate void LLM_StopDelegate(IntPtr LLMObject); - public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate); - public delegate void LLM_SetSSLDelegate(IntPtr LLMObject, string SSLCert, string SSLKey); - public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper); - public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); - public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot); - public delegate int LLM_StatusDelegate(IntPtr LLMObject, IntPtr stringWrapper); - public delegate IntPtr StringWrapper_ConstructDelegate(); - public delegate void StringWrapper_DeleteDelegate(IntPtr instance); - public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance); - public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false); - - public LoggingDelegate Logging; - public StopLoggingDelegate StopLogging; - public LLM_ConstructDelegate LLM_Construct; - public LLM_DeleteDelegate LLM_Delete; - public LLM_StartServerDelegate LLM_StartServer; - public LLM_StopServerDelegate LLM_StopServer; - public LLM_StartDelegate LLM_Start; - public LLM_StartedDelegate LLM_Started; - public LLM_StopDelegate LLM_Stop; - public LLM_SetTemplateDelegate LLM_SetTemplate; - public LLM_SetSSLDelegate LLM_SetSSL; - public LLM_TokenizeDelegate LLM_Tokenize; - public LLM_DetokenizeDelegate LLM_Detokenize; - public LLM_CompletionDelegate LLM_Completion; - public LLM_EmbeddingsDelegate LLM_Embeddings; - public LLM_LoraWeightDelegate LLM_LoraWeight; - public LLM_LoraListDelegate LLM_LoraList; - public LLM_SlotDelegate LLM_Slot; - public LLM_CancelDelegate LLM_Cancel; - public LLM_StatusDelegate LLM_Status; - public StringWrapper_ConstructDelegate StringWrapper_Construct; - public StringWrapper_DeleteDelegate StringWrapper_Delete; - public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize; - public StringWrapper_GetStringDelegate StringWrapper_GetString; - -#endif - - /// - /// Identifies the possible architectures that we can use based on the OS and GPU usage - /// - /// whether to allow GPU architectures - /// possible architectures - public static List PossibleArchitectures(bool gpu = false) - { - List architectures = new List(); - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer || - Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - { - if (gpu) - { - if (LLMUnitySetup.FullLlamaLib) - { - architectures.Add("cuda-cu12.2.0-full"); - } - else - { - architectures.Add("cuda-cu12.2.0"); - } - architectures.Add("hip"); - architectures.Add("vulkan"); - } - if (has_avx512) architectures.Add("avx512"); - if (has_avx2) architectures.Add("avx2"); - if (has_avx) architectures.Add("avx"); - architectures.Add("noavx"); - } - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer) - { - architectures.Add("acc"); - architectures.Add("no_acc"); - } - else if (Application.platform == RuntimePlatform.Android) - { - architectures.Add("android"); - } - else if (Application.platform == RuntimePlatform.IPhonePlayer) - { - architectures.Add("ios"); - } - else if (Application.platform == RuntimePlatform.VisionOS) - { - architectures.Add("visionos"); - } - else - { - string error = "Unknown OS"; - LLMUnitySetup.LogError(error); - throw new Exception(error); - } - return architectures; - } - - /// - /// Allows to retrieve a string from the library (Unity only allows marshalling of chars) - /// - /// string wrapper pointer - /// retrieved string - public string GetStringWrapperResult(IntPtr stringWrapper) - { - string result = ""; - int bufferSize = StringWrapper_GetStringSize(stringWrapper); - if (bufferSize > 1) - { - IntPtr buffer = Marshal.AllocHGlobal(bufferSize); - try - { - StringWrapper_GetString(stringWrapper, buffer, bufferSize); - result = Marshal.PtrToStringAnsi(buffer); - } - finally - { - Marshal.FreeHGlobal(buffer); - } - } - return result; - } - - /// - /// Destroys the LLM library - /// - public void Destroy() - { - if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle); - foreach (IntPtr dependencyHandle in dependencyHandles) LibraryLoader.FreeLibrary(dependencyHandle); - } - } -} -/// \endcond diff --git a/Runtime/LLMLib.cs.meta b/Runtime/LLMLib.cs.meta deleted file mode 100644 index b2c93e3a..00000000 --- a/Runtime/LLMLib.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: bce72731fae1ccb80b4a12f3d616f1ee -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 72e58a06..98813062 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -20,14 +20,13 @@ public class ModelEntry public string filename; public string path; public bool lora; - public string chatTemplate; public string url; public bool embeddingOnly; public int embeddingLength; public bool includeInBuild; public int contextLength; - static List embeddingOnlyArchs = new List {"bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder"}; + static List embeddingOnlyArchs = new List { "bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder" }; /// /// Returns the relative asset path if it is in the AssetPath folder (StreamingAssets or persistentPath), otherwise the filename. @@ -60,7 +59,6 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur this.path = LLMUnitySetup.GetFullPath(path); this.url = url; includeInBuild = true; - chatTemplate = null; contextLength = -1; embeddingOnly = false; embeddingLength = 0; @@ -74,7 +72,6 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur embeddingLength = reader.GetIntField($"{arch}.embedding_length"); } embeddingOnly = embeddingOnlyArchs.Contains(arch); - chatTemplate = embeddingOnly ? default : ChatTemplate.FromGGUF(reader, this.path); } } @@ -98,7 +95,6 @@ public class LLMManagerStore public bool downloadOnStart; public List modelEntries; public int debugMode; - public bool fullLlamaLib; } /// \endcond @@ -167,7 +163,7 @@ public static async Task SetupOnce() else { target = LLMUnitySetup.GetDownloadAssetPath(modelEntry.filename); - downloads.Add(new StringPair {source = modelEntry.url, target = target}); + downloads.Add(new StringPair { source = modelEntry.url, target = target }); } } if (downloads.Count == 0) return true; @@ -206,34 +202,6 @@ public static async Task SetupOnce() return true; } - /// - /// Sets the chat template for a model and distributes it to all LLMs using it - /// - /// model path - /// chat template - public static void SetTemplate(string filename, string chatTemplate) - { - SetTemplate(Get(filename), chatTemplate); - } - - /// - /// Sets the chat template for a model and distributes it to all LLMs using it - /// - /// model entry - /// chat template - public static void SetTemplate(ModelEntry entry, string chatTemplate) - { - if (entry == null) return; - entry.chatTemplate = chatTemplate; - foreach (LLM llm in llms) - { - if (llm != null && llm.model == entry.filename) llm.SetTemplate(chatTemplate); - } -#if UNITY_EDITOR - Save(); -#endif - } - /// /// Gets the model entry for a model path /// @@ -327,7 +295,6 @@ public static void LoadFromDisk() downloadOnStart = store.downloadOnStart; modelEntries = store.modelEntries; LLMUnitySetup.DebugMode = (LLMUnitySetup.DebugModeType)store.debugMode; - LLMUnitySetup.FullLlamaLib = store.fullLlamaLib; } #if UNITY_EDITOR @@ -650,8 +617,7 @@ public static void SaveToDisk() { modelEntries = modelEntriesBuild, downloadOnStart = downloadOnStart, - debugMode = (int)LLMUnitySetup.DebugMode, - fullLlamaLib = LLMUnitySetup.FullLlamaLib + debugMode = (int)LLMUnitySetup.DebugMode }, true); File.WriteAllText(LLMUnitySetup.LLMManagerPath, json); } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index afcb5010..86b16d5a 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -101,19 +101,17 @@ public class LLMUnitySetup { // DON'T CHANGE! the version is autocompleted with a GitHub action /// LLM for Unity version - public static string Version = "v2.5.2"; + public static string Version = "v3.0.0"; /// LlamaLib version - public static string LlamaLibVersion = "v1.2.6"; + public static string LlamaLibVersion = "v2.0.0"; /// LlamaLib release url public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib name - public static string libraryName = GetLibraryName(LlamaLibVersion); + public static string libraryName = $"LlamaLib-{LlamaLibVersion}"; /// LlamaLib path public static string libraryPath = GetAssetPath(libraryName); /// LlamaLib url public static string LlamaLibURL = $"{LlamaLibReleaseURL}/{libraryName}.zip"; - /// LlamaLib extension url - public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/{libraryName}-full.zip"; /// LLMnity store path public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); /// Model download path @@ -122,7 +120,8 @@ public class LLMUnitySetup public static string LLMManagerPath = GetAssetPath("LLMManager.json"); /// Default models for download - [HideInInspector] public static readonly Dictionary modelOptions = new Dictionary() + [HideInInspector] + public static readonly Dictionary modelOptions = new Dictionary() { {"Large models (more than 10B)", new(string, string, string)[] { @@ -168,14 +167,15 @@ public class LLMUnitySetup /// \cond HIDE [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; static string DebugModeKey = "DebugMode"; - public static bool FullLlamaLib = false; - static string FullLlamaLibKey = "FullLlamaLib"; + public static bool CUBLAS = false; + static string CUBLASKey = "CUBLAS"; static List> errorCallbacks = new List>(); static readonly object lockObject = new object(); static Dictionary androidExtractTasks = new Dictionary(); public enum DebugModeType { + Debug, All, Warning, Error, @@ -204,7 +204,7 @@ public static void LogError(string message) static void LoadPlayerPrefs() { DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All); - FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1; + CUBLAS = PlayerPrefs.GetInt(CUBLASKey, 0) == 1; } public static void SetDebugMode(DebugModeType newDebugMode) @@ -216,22 +216,16 @@ public static void SetDebugMode(DebugModeType newDebugMode) } #if UNITY_EDITOR - public static void SetFullLlamaLib(bool value) + public static void SetCUBLAS(bool value) { - if (FullLlamaLib == value) return; - FullLlamaLib = value; - PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0); + if (CUBLAS == value) return; + CUBLAS = value; + PlayerPrefs.SetInt(CUBLASKey, value ? 1 : 0); PlayerPrefs.Save(); - _ = DownloadLibrary(); } #endif - public static string GetLibraryName(string version) - { - return $"undreamai-{version}-llamacpp"; - } - public static string GetAssetPath(string relPath = "") { string assetsDir = Application.platform == RuntimePlatform.Android ? Application.persistentDataPath : Application.streamingAssetsPath; @@ -303,7 +297,7 @@ public static async Task DownloadFile( callback?.Invoke(savePath); } - public static async Task AndroidExtractFile(string assetName, bool overwrite = false, bool log = true, int chunkSize = 1024*1024) + public static async Task AndroidExtractFile(string assetName, bool overwrite = false, bool log = true, int chunkSize = 1024 * 1024) { Task extractionTask; lock (lockObject) @@ -321,7 +315,7 @@ public static async Task AndroidExtractFile(string assetName, bool overwrite = f await extractionTask; } - public static async Task AndroidExtractFileOnce(string assetName, bool overwrite = false, bool log = true, int chunkSize = 1024*1024) + public static async Task AndroidExtractFileOnce(string assetName, bool overwrite = false, bool log = true, int chunkSize = 1024 * 1024) { string source = "jar:file://" + Application.dataPath + "!/assets/" + assetName; string target = GetAssetPath(assetName); @@ -411,14 +405,28 @@ public static void CreateEmptyFile(string path) File.Create(path).Dispose(); } - static void ExtractInsideDirectory(string zipPath, string extractPath, bool overwrite = true) + static void ExtractInsideDirectory(string zipPath, string extractPath, string prefix = "", bool overwrite = true) { using (ZipArchive archive = ZipFile.OpenRead(zipPath)) { foreach (ZipArchiveEntry entry in archive.Entries) { - if (string.IsNullOrEmpty(entry.Name)) continue; - string destinationPath = Path.Combine(extractPath, entry.FullName); + if (string.IsNullOrEmpty(entry.Name)) + continue; // Skip directories + + string destinationPath; + if (!String.IsNullOrEmpty(prefix)) + { + string normalizedPath = entry.FullName.Replace('\\', '/'); + if (!normalizedPath.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) + continue; + destinationPath = Path.Combine(extractPath, normalizedPath.Substring(prefix.Length)); + } + else + { + destinationPath = Path.Combine(extractPath, entry.FullName); + } + Directory.CreateDirectory(Path.GetDirectoryName(destinationPath)); entry.ExtractToFile(destinationPath, overwrite); } @@ -435,7 +443,7 @@ static async Task DownloadAndExtractInsideDirectory(string url, string path, str await DownloadFile(url, zipPath, true, null, SetLibraryProgress); AssetDatabase.StartAssetEditing(); - ExtractInsideDirectory(zipPath, path); + ExtractInsideDirectory(zipPath, path, $"{libraryName}/runtimes/"); CreateEmptyFile(setupFile); AssetDatabase.StopAssetEditing(); @@ -445,23 +453,26 @@ static async Task DownloadAndExtractInsideDirectory(string url, string path, str static void DeleteEarlierVersions() { List assetPathSubDirs = new List(); - foreach (string dir in new string[] {GetAssetPath(), Path.Combine(Application.dataPath, "Plugins", "Android")}) + foreach (string dir in new string[] { GetAssetPath(), Path.Combine(Application.dataPath, "Plugins", "Android") }) { if (Directory.Exists(dir)) assetPathSubDirs.AddRange(Directory.GetDirectories(dir)); } - Regex regex = new Regex(GetLibraryName("(.+)")); + List versionRegexes = new List { new Regex("undreamai-(.+)-llamacpp"), new Regex("LlamaLib-(.+)") }; foreach (string assetPathSubDir in assetPathSubDirs) { - Match match = regex.Match(Path.GetFileName(assetPathSubDir)); - if (match.Success) + foreach (Regex regex in versionRegexes) { - string version = match.Groups[1].Value; - if (version != LlamaLibVersion) + Match match = regex.Match(Path.GetFileName(assetPathSubDir)); + if (match.Success) { - Debug.Log($"Deleting other LLMUnity version folder: {assetPathSubDir}"); - Directory.Delete(assetPathSubDir, true); - if (File.Exists(assetPathSubDir + ".meta")) File.Delete(assetPathSubDir + ".meta"); + string version = match.Groups[1].Value; + if (version != LlamaLibVersion) + { + Debug.Log($"Deleting other LLMUnity version folder: {assetPathSubDir}"); + Directory.Delete(assetPathSubDir, true); + if (File.Exists(assetPathSubDir + ".meta")) File.Delete(assetPathSubDir + ".meta"); + } } } } @@ -481,9 +492,6 @@ static async Task DownloadLibrary() // setup LlamaLib in StreamingAssets await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir); - - // setup LlamaLib extras in StreamingAssets - if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir); } catch (Exception e) { @@ -511,7 +519,7 @@ public static string AddAsset(string assetPath) string filename = Path.GetFileName(assetPath); string fullPath = GetAssetPath(filename); AssetDatabase.StartAssetEditing(); - foreach (string path in new string[] {fullPath, fullPath + ".meta"}) + foreach (string path in new string[] { fullPath, fullPath + ".meta" }) { if (File.Exists(path)) File.Delete(path); } @@ -627,7 +635,7 @@ public static int AndroidGetNumBigCores() int maxFreqKHz = GetMaxFreqKHz(coreIndex); cpuMaxFreqKHz.Add(maxFreqKHz); if (maxFreqKHz > maxFreqKHzMax) maxFreqKHzMax = maxFreqKHz; - if (maxFreqKHz < maxFreqKHzMin) maxFreqKHzMin = maxFreqKHz; + if (maxFreqKHz < maxFreqKHzMin) maxFreqKHzMin = maxFreqKHz; cpuIsSmtCpu.Add(IsSmtCpu(coreIndex)); } } diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs index 43681479..560f2299 100644 --- a/Runtime/LLMUtils.cs +++ b/Runtime/LLMUtils.cs @@ -2,6 +2,8 @@ /// @brief File implementing LLM helper code. using System; using System.Collections.Generic; +using System.Threading; +using UndreamAI.LlamaLib; namespace LLMUnity { @@ -200,5 +202,25 @@ public string[] GetLoras() return loraPaths; } } + + public class Utils + { + public static LlamaLib.CharArrayCallback WrapCallbackForAsync(LlamaLib.CharArrayCallback callback) + { + LlamaLib.CharArrayCallback wrappedCallback = null; + if (callback != null) + { + var context = SynchronizationContext.Current; + wrappedCallback = (string msg) => + { + if (context != null) + context.Post(_ => callback(msg), null); + else + callback(msg); + }; + } + return wrappedCallback; + } + } /// \endcond } diff --git a/Runtime/LlamaLib.meta b/Runtime/LlamaLib.meta new file mode 100644 index 00000000..abad3160 --- /dev/null +++ b/Runtime/LlamaLib.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d41736301318069f486b450daff6db61 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LlamaLib/LLM.cs b/Runtime/LlamaLib/LLM.cs new file mode 100644 index 00000000..5dcd9d20 --- /dev/null +++ b/Runtime/LlamaLib/LLM.cs @@ -0,0 +1,420 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Newtonsoft.Json.Linq; + +namespace UndreamAI.LlamaLib +{ + // Data structures for LoRA operations + public struct LoraIdScale + { + public int Id { get; set; } + public float Scale { get; set; } + + public LoraIdScale(int id, float scale) + { + Id = id; + Scale = scale; + } + } + + public struct LoraIdScalePath + { + public int Id { get; set; } + public float Scale { get; set; } + public string Path { get; set; } + + public LoraIdScalePath(int id, float scale, string path) + { + Id = id; + Scale = scale; + Path = path; + } + } + + // Base LLM class + public abstract class LLM : IDisposable + { + public LlamaLib llamaLib = null; + public IntPtr llm = IntPtr.Zero; + protected readonly object _disposeLock = new object(); + public bool disposed = false; + + protected LLM() {} + + protected LLM(LlamaLib llamaLibInstance) + { + llamaLib = llamaLibInstance ?? throw new ArgumentNullException(nameof(llamaLibInstance)); + } + + public static void Debug(int debugLevel) + { + LlamaLib.Debug(debugLevel); + } + + public static void LoggingCallback(LlamaLib.CharArrayCallback callback) + { + LlamaLib.LoggingCallback(callback); + } + + public static void LoggingStop() + { + LlamaLib.LoggingStop(); + } + + protected void CheckLlamaLib() + { + if (disposed) throw new ObjectDisposedException(GetType().Name); + if (llamaLib == null) throw new InvalidOperationException("LlamaLib instance is not initialized"); + if (llm == IntPtr.Zero) throw new InvalidOperationException("LLM instance is not initialized"); + if (llamaLib.LLM_Status_Code() != 0) + { + string status_msg = Marshal.PtrToStringAnsi(llamaLib.LLM_Status_Message()) ?? string.Empty; + throw new AccessViolationException(status_msg); + } + } + + public virtual void Dispose() {} + + ~LLM() + { + Dispose(); + } + + public string GetTemplate() + { + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Get_Template(llm); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public string ApplyTemplate(JArray messages = null) + { + if (messages == null) + throw new ArgumentNullException(nameof(messages)); + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Apply_Template(llm, messages.ToString()); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public List Tokenize(string content) + { + if (string.IsNullOrEmpty(content)) + throw new ArgumentNullException(nameof(content)); + + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Tokenize(llm, content); + string resultStr = Marshal.PtrToStringAnsi(result) ?? string.Empty; + List ret = new List(); + try + { + JArray json = JArray.Parse(resultStr); + ret = json?.ToObject>(); + } + catch {} + return ret; + } + + public string Detokenize(List tokens) + { + if (tokens == null) + throw new ArgumentNullException(nameof(tokens)); + + CheckLlamaLib(); + JArray tokensJSON = JArray.FromObject(tokens); + IntPtr result = llamaLib.LLM_Detokenize(llm, tokensJSON.ToString()); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public string Detokenize(int[] tokens) + { + if (tokens == null) + throw new ArgumentNullException(nameof(tokens)); + return Detokenize(new List(tokens)); + } + + public List Embeddings(string content) + { + if (string.IsNullOrEmpty(content)) + throw new ArgumentNullException(nameof(content)); + + CheckLlamaLib(); + + IntPtr result = llamaLib.LLM_Embeddings(llm, content); + string resultStr = Marshal.PtrToStringAnsi(result) ?? string.Empty; + + List ret = new List(); + try + { + JArray json = JArray.Parse(resultStr); + ret = json?.ToObject>(); + } + catch {} + return ret; + } + + public void SetCompletionParameters(JObject parameters = null) + { + CheckLlamaLib(); + llamaLib.LLM_Set_Completion_Parameters(llm, parameters.ToString()); + } + + public JObject GetCompletionParameters() + { + CheckLlamaLib(); + JObject parameters = new JObject(); + IntPtr result = llamaLib.LLM_Get_Completion_Parameters(llm); + string parametersString = Marshal.PtrToStringAnsi(result) ?? "{}"; + try + { + parameters = JObject.Parse(parametersString); + } + catch {} + return parameters; + } + + public void SetGrammar(string grammar) + { + CheckLlamaLib(); + llamaLib.LLM_Set_Grammar(llm, grammar); + } + + public string GetGrammar() + { + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Get_Grammar(llm); + return Marshal.PtrToStringAnsi(result) ?? ""; + } + + public void CheckCompletionInternal(string prompt) + { + if (string.IsNullOrEmpty(prompt)) + throw new ArgumentNullException(nameof(prompt)); + CheckLlamaLib(); + } + + public string CompletionInternal(string prompt, LlamaLib.CharArrayCallback callback, int idSlot) + { + IntPtr result; + result = llamaLib.LLM_Completion(llm, prompt, callback, idSlot); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public string Completion(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1) + { + CheckCompletionInternal(prompt); + return CompletionInternal(prompt, callback, idSlot); + } + + public async Task CompletionAsync(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1) + { + CheckCompletionInternal(prompt); + return await Task.Run(() => CompletionInternal(prompt, callback, idSlot)); + } + } + + // LLMLocal class + public abstract class LLMLocal : LLM + { + protected LLMLocal() : base() {} + + protected LLMLocal(LlamaLib llamaLibInstance) : base(llamaLibInstance) {} + + public string SaveSlot(int idSlot, string filepath) + { + if (string.IsNullOrEmpty(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + IntPtr result = llamaLib.LLM_Save_Slot(llm, idSlot, filepath); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public string LoadSlot(int idSlot, string filepath) + { + if (string.IsNullOrEmpty(filepath) || !File.Exists(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + IntPtr result = llamaLib.LLM_Load_Slot(llm, idSlot, filepath); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public void Cancel(int idSlot) + { + CheckLlamaLib(); + llamaLib.LLM_Cancel(llm, idSlot); + } + } + + // LLMProvider class + public abstract class LLMProvider : LLMLocal + { + protected LLMProvider() : base() {} + + protected LLMProvider(LlamaLib llamaLibInstance) : base(llamaLibInstance) {} + + + public void SetTemplate(string template) + { + CheckLlamaLib(); + llamaLib.LLM_Set_Template(llm, template); + } + + // LoRA Weight methods + public string BuildLoraWeightJSON(List loras) + { + var jsonArray = new JArray(); + foreach (var lora in loras) + { + jsonArray.Add(new JObject { ["id"] = lora.Id, ["scale"] = lora.Scale }); + } + return jsonArray.ToString(); + } + + public bool LoraWeight(List loras) + { + if (loras == null) + throw new ArgumentNullException(nameof(loras)); + + var lorasJSON = BuildLoraWeightJSON(loras); + return llamaLib.LLM_Lora_Weight(llm, lorasJSON); + } + + public bool LoraWeight(params LoraIdScale[] loras) + { + return LoraWeight(new List(loras)); + } + + // LoRA List methods + public List ParseLoraListJSON(string result) + { + var loras = new List(); + try + { + var jsonArray = JArray.Parse(result); + foreach (var item in jsonArray) + { + int id = item["id"]?.ToObject() ?? -1; + if (id < 0) continue; + loras.Add(new LoraIdScalePath( + id, + item["scale"]?.ToObject() ?? 0.0f, + item["path"]?.ToString() ?? string.Empty + )); + } + } + catch {} + return loras; + } + + public string LoraListJSON() + { + CheckLlamaLib(); + var result = llamaLib.LLM_Lora_List(llm); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public List LoraList() + { + var jsonResult = LoraListJSON(); + return ParseLoraListJSON(jsonResult); + } + + // Server methods + public bool Start() + { + CheckLlamaLib(); + llamaLib.LLM_Start(llm); + return llamaLib.LLM_Started(llm); + } + + public async Task StartAsync() + { + CheckLlamaLib(); + return await Task.Run(() => + { + llamaLib.LLM_Start(llm); + return llamaLib.LLM_Started(llm); + }); + } + + public bool Started() + { + CheckLlamaLib(); + return llamaLib.LLM_Started(llm); + } + + public void Stop() + { + CheckLlamaLib(); + llamaLib.LLM_Stop(llm); + } + + public void StartServer(string host = "0.0.0.0", int port = -1, string apiKey = "") + { + CheckLlamaLib(); + if (string.IsNullOrEmpty(host)) + host = "0.0.0.0"; + + llamaLib.LLM_Start_Server(llm, host, port, apiKey ?? string.Empty); + } + + public void StopServer() + { + CheckLlamaLib(); + llamaLib.LLM_Stop_Server(llm); + } + + public void JoinService() + { + CheckLlamaLib(); + llamaLib.LLM_Join_Service(llm); + } + + public void JoinServer() + { + CheckLlamaLib(); + llamaLib.LLM_Join_Server(llm); + } + + public void SetSSL(string sslCert, string sslKey) + { + if (string.IsNullOrEmpty(sslCert)) + throw new ArgumentNullException(nameof(sslCert)); + if (string.IsNullOrEmpty(sslKey)) + throw new ArgumentNullException(nameof(sslKey)); + + CheckLlamaLib(); + llamaLib.LLM_Set_SSL(llm, sslCert, sslKey); + } + + public int EmbeddingSize() + { + CheckLlamaLib(); + return llamaLib.LLM_Embedding_Size(llm); + } + + public override void Dispose() + { + lock (_disposeLock) + { + if (!disposed) + { + if (llm != IntPtr.Zero && llamaLib != null) + { + try + { + llamaLib.LLM_Delete(llm); + } + catch (Exception) {} + } + llamaLib?.Dispose(); + llamaLib = null; + llm = IntPtr.Zero; + } + disposed = true; + } + } + } +} diff --git a/Runtime/LLMChatTemplates.cs.meta b/Runtime/LlamaLib/LLM.cs.meta similarity index 83% rename from Runtime/LLMChatTemplates.cs.meta rename to Runtime/LlamaLib/LLM.cs.meta index 5cfa7005..46805ab9 100644 --- a/Runtime/LLMChatTemplates.cs.meta +++ b/Runtime/LlamaLib/LLM.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 7aafa0738b61af4af85e44b81b5625ca +guid: 4c5139afa9c5c83568c0b622b0863a72 MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/Runtime/LlamaLib/LLMAgent.cs b/Runtime/LlamaLib/LLMAgent.cs new file mode 100644 index 00000000..11ee4a0e --- /dev/null +++ b/Runtime/LlamaLib/LLMAgent.cs @@ -0,0 +1,299 @@ +using System; +using System.IO; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Newtonsoft.Json.Linq; + +namespace UndreamAI.LlamaLib +{ + // Data structure for chat messages + public struct ChatMessage + { + public string Role { get; set; } + public string Content { get; set; } + + public ChatMessage(string role, string content) + { + Role = role; + Content = content; + } + + public JObject ToJson() + { + return new JObject + { + ["role"] = Role, + ["content"] = Content + }; + } + + public static ChatMessage FromJson(JObject json) + { + return new ChatMessage( + json["role"]?.ToString() ?? string.Empty, + json["content"]?.ToString() ?? string.Empty + ); + } + } + + // LLMAgent class + public class LLMAgent : LLMLocal + { + private LLMLocal llmBase; + + public LLMAgent(LLMLocal _llm, string _systemPrompt = "", string _userRole = "user", string _assistantRole = "assistant") + { + if (_llm == null) + throw new ArgumentNullException(nameof(_llm)); + if (_llm.disposed) + throw new ObjectDisposedException(nameof(_llm)); + + llmBase = _llm; + llamaLib = llmBase.llamaLib; + + llm = llamaLib.LLMAgent_Construct(llmBase.llm, _systemPrompt ?? "", _userRole ?? "user", _assistantRole ?? "assistant"); + if (llm == IntPtr.Zero) throw new InvalidOperationException("Failed to create LLMAgent"); + } + + // Properties + public int SlotId + { + get + { + CheckLlamaLib(); + return llamaLib.LLMAgent_Get_Slot(llm); + } + set + { + CheckLlamaLib(); + llamaLib.LLMAgent_Set_Slot(llm, value); + } + } + + public string UserRole + { + get + { + CheckLlamaLib(); + return Marshal.PtrToStringAnsi(llamaLib.LLMAgent_Get_User_Role(llm)) ?? ""; + } + set + { + CheckLlamaLib(); + llamaLib.LLMAgent_Set_User_Role(llm, value); + } + } + + public string AssistantRole + { + get + { + CheckLlamaLib(); + return Marshal.PtrToStringAnsi(llamaLib.LLMAgent_Get_Assistant_Role(llm)) ?? ""; + } + set + { + CheckLlamaLib(); + llamaLib.LLMAgent_Set_Assistant_Role(llm, value); + } + } + + public string SystemPrompt + { + get + { + CheckLlamaLib(); + return Marshal.PtrToStringAnsi(llamaLib.LLMAgent_Get_System_Prompt(llm)) ?? ""; + } + set + { + CheckLlamaLib(); + llamaLib.LLMAgent_Set_System_Prompt(llm, value); + } + } + + // History management + public JArray History + { + get + { + CheckLlamaLib(); + IntPtr result = llamaLib.LLMAgent_Get_History(llm); + string historyStr = Marshal.PtrToStringAnsi(result) ?? "[]"; + try + { + return JArray.Parse(historyStr); + } + catch + { + return new JArray(); + } + } + set + { + CheckLlamaLib(); + string historyJson = value?.ToString() ?? "[]"; + llamaLib.LLMAgent_Set_History(llm, historyJson); + } + } + + public List GetHistory() + { + var history = History; + var messages = new List(); + + try + { + foreach (var item in history) + { + if (item is JObject messageObj) + { + messages.Add(ChatMessage.FromJson(messageObj)); + } + } + } + catch {} + + return messages; + } + + public void SetHistory(List messages) + { + if (messages == null) + throw new ArgumentNullException(nameof(messages)); + + var historyArray = new JArray(); + foreach (var message in messages) + { + historyArray.Add(message.ToJson()); + } + History = historyArray; + } + + public void ClearHistory() + { + CheckLlamaLib(); + llamaLib.LLMAgent_Clear_History(llm); + } + + public void AddMessage(string role, string content) + { + CheckLlamaLib(); + llamaLib.LLMAgent_Add_Message(llm, role, content); + } + + public void AddUserMessage(string content) + { + CheckLlamaLib(); + llamaLib.LLMAgent_Add_Message(llm, UserRole, content); + } + + public void AddAssistantMessage(string content) + { + CheckLlamaLib(); + llamaLib.LLMAgent_Add_Message(llm, AssistantRole, content); + } + + public void AddMessage(ChatMessage message) + { + AddMessage(message.Role, message.Content); + } + + public void RemoveLastMessage() + { + CheckLlamaLib(); + llamaLib.LLMAgent_Remove_Last_Message(llm); + } + + public void SaveHistory(string filepath) + { + if (string.IsNullOrEmpty(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + CheckLlamaLib(); + llamaLib.LLMAgent_Save_History(llm, filepath); + } + + public void LoadHistory(string filepath) + { + if (string.IsNullOrEmpty(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + CheckLlamaLib(); + llamaLib.LLMAgent_Load_History(llm, filepath); + } + + public int GetHistorySize() + { + CheckLlamaLib(); + return llamaLib.LLMAgent_Get_History_Size(llm); + } + + // Chat functionality + public string Chat(string userPrompt, bool addToHistory = true, LlamaLib.CharArrayCallback callback = null, bool returnResponseJson = false) + { + CheckLlamaLib(); + IntPtr result = llamaLib.LLMAgent_Chat(llm, userPrompt, addToHistory, callback, returnResponseJson); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public async Task ChatAsync(string userPrompt, bool addToHistory = true, LlamaLib.CharArrayCallback callback = null, bool returnResponseJson = false) + { + return await Task.Run(() => Chat(userPrompt, addToHistory, callback, returnResponseJson)); + } + + // Override completion methods to use agent-specific implementations + public string Completion(string prompt, LlamaLib.CharArrayCallback callback = null) + { + return Completion(prompt, callback, SlotId); + } + + public async Task CompletionAsync(string prompt, LlamaLib.CharArrayCallback callback = null) + { + return await Task.Run(() => Completion(prompt, callback)); + } + + public string SaveSlot(string filepath) + { + if (string.IsNullOrEmpty(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Save_Slot(llm, SlotId, filepath); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public string LoadSlot(string filepath) + { + if (string.IsNullOrEmpty(filepath) || !File.Exists(filepath)) + throw new ArgumentNullException(nameof(filepath)); + + CheckLlamaLib(); + IntPtr result = llamaLib.LLM_Load_Slot(llm, SlotId, filepath); + return Marshal.PtrToStringAnsi(result) ?? string.Empty; + } + + public void Cancel() + { + CheckLlamaLib(); + llamaLib.LLM_Cancel(llm, SlotId); + } + + // Override slot-based methods to hide them + private new string SaveSlot(int id_slot, string filepath) + { + return SaveSlot(filepath); + } + + private new string LoadSlot(int id_slot, string filepath) + { + return LoadSlot(filepath); + } + + private new void Cancel(int id_slot) + { + Cancel(); + } + } +} diff --git a/Runtime/LLMInterface.cs.meta b/Runtime/LlamaLib/LLMAgent.cs.meta similarity index 83% rename from Runtime/LLMInterface.cs.meta rename to Runtime/LlamaLib/LLMAgent.cs.meta index 66038d3a..6af83406 100644 --- a/Runtime/LLMInterface.cs.meta +++ b/Runtime/LlamaLib/LLMAgent.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 3c7da7ca2b36d79e99acecc300de2902 +guid: e046cf547a885a47987b2de08256d023 MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/Runtime/LlamaLib/LLMClient.cs b/Runtime/LlamaLib/LLMClient.cs new file mode 100644 index 00000000..edc0ab44 --- /dev/null +++ b/Runtime/LlamaLib/LLMClient.cs @@ -0,0 +1,49 @@ +using System; + +namespace UndreamAI.LlamaLib +{ + public class LLMClient : LLMLocal + { + public LLMClient(LLMProvider provider) + { + if (provider.disposed) + throw new ObjectDisposedException(nameof(provider)); + + llamaLib = provider.llamaLib; + llm = CreateClient(provider); + } + + public LLMClient(string url, int port, string apiKey = "") + { + if (string.IsNullOrEmpty(url)) + throw new ArgumentNullException(nameof(url)); + + try + { + llamaLib = new LlamaLib(false); + llm = CreateRemoteClient(url, port, apiKey); + } + catch + { + llamaLib?.Dispose(); + throw; + } + } + + private IntPtr CreateClient(LLMProvider provider) + { + var llm = llamaLib.LLMClient_Construct(provider.llm); + if (llm == IntPtr.Zero) + throw new InvalidOperationException("Failed to create LLMClient"); + return llm; + } + + private IntPtr CreateRemoteClient(string url, int port, string apiKey = "") + { + var llm = llamaLib.LLMClient_Construct_Remote(url, port, apiKey); + if (llm == IntPtr.Zero) + throw new InvalidOperationException($"Failed to create remote LLMClient for {url}:{port}"); + return llm; + } + } +} \ No newline at end of file diff --git a/Runtime/LlamaLib/LLMClient.cs.meta b/Runtime/LlamaLib/LLMClient.cs.meta new file mode 100644 index 00000000..c4aac8b5 --- /dev/null +++ b/Runtime/LlamaLib/LLMClient.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 37ba7749ae9efccc58368375fa8654e0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LlamaLib/LLMService.cs b/Runtime/LlamaLib/LLMService.cs new file mode 100644 index 00000000..4347023f --- /dev/null +++ b/Runtime/LlamaLib/LLMService.cs @@ -0,0 +1,130 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace UndreamAI.LlamaLib +{ + public class LLMService : LLMProvider + { + public LLMService(string modelPath, int numSlots = 1, + int numThreads = -1, int numGpuLayers = 0, + bool flashAttention = false, int contextSize = 4096, + int batchSize = 2048, bool embeddingOnly = false, string[] loraPaths = null) + { + if (string.IsNullOrEmpty(modelPath)) + throw new ArgumentNullException(nameof(modelPath)); + if (!File.Exists(modelPath)) + throw new FileNotFoundException($"Model file not found: {modelPath}"); + + try + { + llamaLib = new LlamaLib(numGpuLayers > 0); + llm = CreateLLM( + llamaLib, + modelPath, numSlots, numThreads, numGpuLayers, + flashAttention, contextSize, batchSize, embeddingOnly, loraPaths); + } + catch + { + llamaLib?.Dispose(); + throw; + } + } + + public LLMService(LlamaLib llamaLibInstance, IntPtr llmInstance) + { + if (llamaLibInstance == null) throw new ArgumentNullException(nameof(llamaLibInstance)); + if (llmInstance == IntPtr.Zero) throw new ArgumentNullException(nameof(llmInstance)); + llamaLib = llamaLibInstance; + llm = llmInstance; + } + + public static LLMService FromCommand(string paramsString) + { + if (string.IsNullOrEmpty(paramsString)) + throw new ArgumentNullException(nameof(paramsString)); + + LlamaLib llamaLibInstance = null; + IntPtr llmInstance = IntPtr.Zero; + try + { + llamaLibInstance = new LlamaLib(LlamaLib.Has_GPU_Layers(paramsString)); + llmInstance = llamaLibInstance.LLMService_From_Command(paramsString); + } + catch + { + llamaLibInstance?.Dispose(); + throw; + } + return new LLMService(llamaLibInstance, llmInstance); + } + + public static IntPtr CreateLLM( + LlamaLib llamaLib, + string modelPath, int numSlots, int numThreads, + int numGpuLayers, bool flashAttention, int contextSize, int batchSize, + bool embeddingOnly, string[] loraPaths) + { + IntPtr loraPathsPtr = IntPtr.Zero; + int loraPathCount = 0; + + if (loraPaths != null && loraPaths.Length > 0) + { + loraPathCount = loraPaths.Length; + // Allocate array of string pointers + loraPathsPtr = Marshal.AllocHGlobal(IntPtr.Size * loraPathCount); + + try + { + for (int i = 0; i < loraPathCount; i++) + { + if (string.IsNullOrEmpty(loraPaths[i])) + throw new ArgumentException($"Lora path at index {i} is null or empty"); + + IntPtr stringPtr = Marshal.StringToHGlobalAnsi(loraPaths[i]); + Marshal.WriteIntPtr(loraPathsPtr, i * IntPtr.Size, stringPtr); + } + } + catch + { + // Clean up if allocation failed + for (int i = 0; i < loraPathCount; i++) + { + IntPtr stringPtr = Marshal.ReadIntPtr(loraPathsPtr, i * IntPtr.Size); + if (stringPtr != IntPtr.Zero) + Marshal.FreeHGlobal(stringPtr); + } + Marshal.FreeHGlobal(loraPathsPtr); + throw; + } + } + + try + { + var llm = llamaLib.LLMService_Construct( + modelPath, numSlots, numThreads, numGpuLayers, + flashAttention, contextSize, batchSize, embeddingOnly, + loraPathCount, loraPathsPtr); + + if (llm == IntPtr.Zero) + throw new InvalidOperationException("Failed to create LLMService"); + + return llm; + } + finally + { + // Clean up allocated strings + if (loraPathsPtr != IntPtr.Zero) + { + for (int i = 0; i < loraPathCount; i++) + { + IntPtr stringPtr = Marshal.ReadIntPtr(loraPathsPtr, i * IntPtr.Size); + if (stringPtr != IntPtr.Zero) + Marshal.FreeHGlobal(stringPtr); + } + Marshal.FreeHGlobal(loraPathsPtr); + } + } + } + } +} diff --git a/Runtime/LlamaLib/LLMService.cs.meta b/Runtime/LlamaLib/LLMService.cs.meta new file mode 100644 index 00000000..e45f2e73 --- /dev/null +++ b/Runtime/LlamaLib/LLMService.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 72670472c9b43749c9035c00bb020e52 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LlamaLib/LibraryLoader.cs b/Runtime/LlamaLib/LibraryLoader.cs new file mode 100644 index 00000000..93b3a3f4 --- /dev/null +++ b/Runtime/LlamaLib/LibraryLoader.cs @@ -0,0 +1,221 @@ +/// @file +/// @brief File implementing the LlamaLib library loader +/// \cond HIDE +using System; +using System.Runtime.InteropServices; + +namespace UndreamAI.LlamaLib +{ + /// @ingroup utils + /// + /// Class implementing the LlamaLib library loader + /// Adapted from SkiaForUnity: + /// https://github.com/ammariqais/SkiaForUnity/blob/f43322218c736d1c41f3a3df9355b90db4259a07/SkiaUnity/Assets/SkiaSharp/SkiaSharp-Bindings/SkiaSharp.HarfBuzz.Shared/HarfBuzzSharp.Shared/LibraryLoader.cs + /// + static class LibraryLoader + { + /// + /// Allows to retrieve a function delegate for the library + /// + /// type to cast the function + /// library handle + /// function name + /// function delegate + public static T GetSymbolDelegate(IntPtr library, string name) where T : Delegate + { + var symbol = GetSymbol(library, name); + if (symbol == IntPtr.Zero) + throw new EntryPointNotFoundException($"Unable to load symbol '{name}'."); + + return Marshal.GetDelegateForFunctionPointer(symbol); + } + /// + /// Loads the provided library in a cross-platform manner + /// + /// library path + /// library handle + public static IntPtr LoadLibrary(string libraryPath) + { + if (string.IsNullOrEmpty(libraryPath)) + throw new ArgumentNullException(nameof(libraryPath)); + +#if ANDROID || IOS || VISIONOS + return Mobile.dlopen(libraryPath); +#else + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return Win32.LoadLibrary(libraryPath); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return Linux.dlopen(libraryPath); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + return Mac.dlopen(libraryPath); + else throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryPath}'."); +#endif + } + + /// + /// Retrieve a function delegate for the library in a cross-platform manner + /// + /// library handle + /// function name + /// function handle + public static IntPtr GetSymbol(IntPtr library, string symbolName) + { + if (string.IsNullOrEmpty(symbolName)) + throw new ArgumentNullException(nameof(symbolName)); + +#if ANDROID || IOS || VISIONOS + return Mobile.dlsym(library, symbolName); +#else + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return Win32.GetProcAddress(library, symbolName); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return Linux.dlsym(library, symbolName); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + return Mac.dlsym(library, symbolName); + else throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}."); +#endif + } + + /// + /// Frees up the library + /// + /// library handle + public static void FreeLibrary(IntPtr library) + { + if (library == IntPtr.Zero) + return; + +#if ANDROID || IOS || VISIONOS + Mobile.dlclose(library); +#else + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + Win32.FreeLibrary(library); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + Linux.dlclose(library); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + Mac.dlclose(library); + else throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'."); +#endif + } + + private static class Mac + { + private const string SystemLibrary = "/usr/lib/libSystem.dylib"; + + private const int RTLD_LAZY = 1; + private const int RTLD_NOW = 2; + + public static IntPtr dlopen(string path, bool lazy = true) => + dlopen(path, lazy ? RTLD_LAZY : RTLD_NOW); + + [DllImport(SystemLibrary)] + public static extern IntPtr dlopen(string path, int mode); + + [DllImport(SystemLibrary)] + public static extern IntPtr dlsym(IntPtr handle, string symbol); + + [DllImport(SystemLibrary)] + public static extern void dlclose(IntPtr handle); + } + + private static class Linux + { + private const string SystemLibrary = "libdl.so"; + private const string SystemLibrary2 = "libdl.so.2"; // newer Linux distros use this + + private const int RTLD_LAZY = 1; + private const int RTLD_NOW = 2; + + private static bool UseSystemLibrary2 = true; + + public static IntPtr dlopen(string path, bool lazy = true) + { + try + { + return dlopen2(path, lazy ? RTLD_LAZY : RTLD_NOW); + } + catch (DllNotFoundException) + { + UseSystemLibrary2 = false; + return dlopen1(path, lazy ? RTLD_LAZY : RTLD_NOW); + } + } + + public static IntPtr dlsym(IntPtr handle, string symbol) + { + return UseSystemLibrary2 ? dlsym2(handle, symbol) : dlsym1(handle, symbol); + } + + public static void dlclose(IntPtr handle) + { + if (UseSystemLibrary2) + dlclose2(handle); + else + dlclose1(handle); + } + + [DllImport(SystemLibrary, EntryPoint = "dlopen")] + private static extern IntPtr dlopen1(string path, int mode); + + [DllImport(SystemLibrary, EntryPoint = "dlsym")] + private static extern IntPtr dlsym1(IntPtr handle, string symbol); + + [DllImport(SystemLibrary, EntryPoint = "dlclose")] + private static extern void dlclose1(IntPtr handle); + + [DllImport(SystemLibrary2, EntryPoint = "dlopen")] + private static extern IntPtr dlopen2(string path, int mode); + + [DllImport(SystemLibrary2, EntryPoint = "dlsym")] + private static extern IntPtr dlsym2(IntPtr handle, string symbol); + + [DllImport(SystemLibrary2, EntryPoint = "dlclose")] + private static extern void dlclose2(IntPtr handle); + } + + private static class Win32 + { + private const string SystemLibrary = "Kernel32.dll"; + + [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)] + public static extern IntPtr LoadLibrary(string lpFileName); + + [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)] + public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName); + + [DllImport(SystemLibrary, SetLastError = true)] + public static extern bool FreeLibrary(IntPtr hModule); + } + + private static class Mobile + { + public static IntPtr dlopen(string path) => dlopen(path, 1); + +#if ANDROID || IOS || VISIONOS + [DllImport("__Internal")] + public static extern IntPtr dlopen(string filename, int flags); + + [DllImport("__Internal")] + public static extern IntPtr dlsym(IntPtr handle, string symbol); + + [DllImport("__Internal")] + public static extern int dlclose(IntPtr handle); +#else + public static IntPtr dlopen(string filename, int flags) + { + return IntPtr.Zero; + } + + public static IntPtr dlsym(IntPtr handle, string symbol) + { + return IntPtr.Zero; + } + + public static int dlclose(IntPtr handle) + { + return 0; + } +#endif + } + } +} \ No newline at end of file diff --git a/Runtime/LlamaLib/LibraryLoader.cs.meta b/Runtime/LlamaLib/LibraryLoader.cs.meta new file mode 100644 index 00000000..7ea28b9c --- /dev/null +++ b/Runtime/LlamaLib/LibraryLoader.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 97d9c4e088cf8d3f1a29b468daff47b7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LlamaLib/LlamaLib.cs b/Runtime/LlamaLib/LlamaLib.cs new file mode 100644 index 00000000..fbec7625 --- /dev/null +++ b/Runtime/LlamaLib/LlamaLib.cs @@ -0,0 +1,769 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; +using System.Reflection; +using System.Collections.Generic; + +namespace UndreamAI.LlamaLib +{ + public class LlamaLib + { + public string architecture { get; private set; } + + // Function delegates + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void CharArrayCallback([MarshalAs(UnmanagedType.LPStr)] string charArray); + +#if ANDROID || IOS || VISIONOS + // Static P/Invoke declarations for mobile platforms +#if ANDROID_ARM64 + public const string DllName = "libllamalib_android-arm64"; +#elif ANDROID_X64 + public const string DllName = "libllamalib_android-x64"; +#else + public const string DllName = "__Internal"; +#endif + + public LlamaLib(bool gpu = false) + { +#if ANDROID_ARM64 + architecture = "android-arm64"; +#elif ANDROID_X64 + architecture = "android-x64"; +#elif IOS + architecture = "ios-arm64"; +#elif VISIONOS + architecture = "visionos-arm64"; +#endif + } + + // Base LLM functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Debug")] + public static extern void LLM_Debug_Static(int debugLevel); + public static void Debug(int debugLevel) => LLM_Debug_Static(debugLevel); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Logging_Callback")] + public static extern void LLM_Logging_Callback_Static(CharArrayCallback callback); + public static void LoggingCallback(CharArrayCallback callback) => LLM_Logging_Callback_Static(callback); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Logging_Stop")] + public static extern void LLM_Logging_Stop_Static(); + public static void LoggingStop() => LLM_Logging_Stop_Static(); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Get_Template")] + public static extern IntPtr LLM_Get_Template_Static(IntPtr llm); + public IntPtr LLM_Get_Template(IntPtr llm) => LLM_Get_Template_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Apply_Template")] + public static extern IntPtr LLM_Apply_Template_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string messages_as_json); + public IntPtr LLM_Apply_Template(IntPtr llm, string messages_as_json) => LLM_Apply_Template_Static(llm, messages_as_json); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Tokenize")] + public static extern IntPtr LLM_Tokenize_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query); + public IntPtr LLM_Tokenize(IntPtr llm, string query) => LlamaLib.LLM_Tokenize_Static(llm, query); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Detokenize")] + public static extern IntPtr LLM_Detokenize_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string tokens_as_json); + public IntPtr LLM_Detokenize(IntPtr llm, string tokens_as_json) => LlamaLib.LLM_Detokenize_Static(llm, tokens_as_json); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Embeddings")] + public static extern IntPtr LLM_Embeddings_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query); + public IntPtr LLM_Embeddings(IntPtr llm, string query) => LlamaLib.LLM_Embeddings_Static(llm, query); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Completion")] + public static extern IntPtr LLM_Completion_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query, CharArrayCallback callback, int id_slot = -1, bool return_response_json = false); + public IntPtr LLM_Completion(IntPtr llm, string query, CharArrayCallback callback, int id_slot = -1, bool return_response_json = false) => LlamaLib.LLM_Completion_Static(llm, query, callback, id_slot, return_response_json); + + // LLMLocal functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Set_Template")] + public static extern IntPtr LLM_Set_Template_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string template); + public IntPtr LLM_Set_Template(IntPtr llm, string template) => LLM_Set_Template_Static(llm, template); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Save_Slot")] + public static extern IntPtr LLM_Save_Slot_Static(IntPtr llm, int id_slot, [MarshalAs(UnmanagedType.LPStr)] string filepath); + public IntPtr LLM_Save_Slot(IntPtr llm, int id_slot, string filepath) => LlamaLib.LLM_Save_Slot_Static(llm, id_slot, filepath); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Load_Slot")] + public static extern IntPtr LLM_Load_Slot_Static(IntPtr llm, int id_slot, [MarshalAs(UnmanagedType.LPStr)] string filepath); + public IntPtr LLM_Load_Slot(IntPtr llm, int id_slot, string filepath) => LlamaLib.LLM_Load_Slot_Static(llm, id_slot, filepath); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Cancel")] + public static extern void LLM_Cancel_Static(IntPtr llm, int idSlot); + public void LLM_Cancel(IntPtr llm, int idSlot) => LlamaLib.LLM_Cancel_Static(llm, idSlot); + + // LLMProvider functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Lora_Weight")] + public static extern bool LLM_Lora_Weight_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string loras_as_json); + public bool LLM_Lora_Weight(IntPtr llm, string loras_as_json) => LlamaLib.LLM_Lora_Weight_Static(llm, loras_as_json); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Lora_List")] + public static extern IntPtr LLM_Lora_List_Static(IntPtr llm); + public IntPtr LLM_Lora_List(IntPtr llm) => LlamaLib.LLM_Lora_List_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Delete")] + public static extern void LLM_Delete_Static(IntPtr llm); + public void LLM_Delete(IntPtr llm) => LlamaLib.LLM_Delete_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Start")] + public static extern void LLM_Start_Static(IntPtr llm); + public void LLM_Start(IntPtr llm) => LlamaLib.LLM_Start_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Started")] + [return : MarshalAs(UnmanagedType.I1)] + public static extern bool LLM_Started_Static(IntPtr llm); + public bool LLM_Started(IntPtr llm) => LlamaLib.LLM_Started_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Stop")] + public static extern void LLM_Stop_Static(IntPtr llm); + public void LLM_Stop(IntPtr llm) => LlamaLib.LLM_Stop_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Start_Server")] + public static extern void LLM_Start_Server_Static(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string host = "0.0.0.0", + int port = -1, + [MarshalAs(UnmanagedType.LPStr)] string apiKey = ""); + public void LLM_Start_Server(IntPtr llm, string host = "0.0.0.0", int port = -1, string apiKey = "") => LlamaLib.LLM_Start_Server_Static(llm, host, port, apiKey); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Stop_Server")] + public static extern void LLM_Stop_Server_Static(IntPtr llm); + public void LLM_Stop_Server(IntPtr llm) => LlamaLib.LLM_Stop_Server_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Join_Service")] + public static extern void LLM_Join_Service_Static(IntPtr llm); + public void LLM_Join_Service(IntPtr llm) => LlamaLib.LLM_Join_Service_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Join_Server")] + public static extern void LLM_Join_Server_Static(IntPtr llm); + public void LLM_Join_Server(IntPtr llm) => LlamaLib.LLM_Join_Server_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Set_SSL")] + public static extern void LLM_Set_SSL_Static(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string sslCert, + [MarshalAs(UnmanagedType.LPStr)] string sslKey); + public void LLM_Set_SSL(IntPtr llm, string sslCert, string sslKey) => LlamaLib.LLM_Set_SSL_Static(llm, sslCert, sslKey); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Status_Code")] + public static extern int LLM_Status_Code_Static(); + public int LLM_Status_Code() => LlamaLib.LLM_Status_Code_Static(); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Status_Message")] + public static extern IntPtr LLM_Status_Message_Static(); + public IntPtr LLM_Status_Message() => LlamaLib.LLM_Status_Message_Static(); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Embedding_Size")] + public static extern int LLM_Embedding_Size_Static(IntPtr llm); + public int LLM_Embedding_Size(IntPtr llm) => LlamaLib.LLM_Embedding_Size_Static(llm); + + // LLMService functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMService_Construct")] + public static extern IntPtr LLMService_Construct_Static( + [MarshalAs(UnmanagedType.LPStr)] string modelPath, + int numSlots = 1, + int numThreads = -1, + int numGpuLayers = 0, + [MarshalAs(UnmanagedType.I1)] bool flashAttention = false, + int contextSize = 4096, + int batchSize = 2048, + [MarshalAs(UnmanagedType.I1)] bool embeddingOnly = false, + int loraCount = 0, + IntPtr loraPaths = default); + public IntPtr LLMService_Construct( + int numSlots = 1, + string modelPath, + int numThreads = -1, + int numGpuLayers = 0, + bool flashAttention = false, + int contextSize = 4096, + int batchSize = 2048, + bool embeddingOnly = false, + int loraCount = 0, + IntPtr loraPaths = default) + => LlamaLib.LLMService_Construct_Static(modelPath, numSlots, numThreads, numGpuLayers, flashAttention, + contextSize, batchSize, embeddingOnly, loraCount, loraPaths); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMService_From_Command")] + public static extern IntPtr LLMService_From_Command_Static([MarshalAs(UnmanagedType.LPStr)] string paramsString); + public IntPtr LLMService_From_Command(string paramsString) => LlamaLib.LLMService_From_Command_Static(paramsString); + + // LLMClient functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMClient_Construct")] + public static extern IntPtr LLMClient_Construct_Static(IntPtr llm); + public IntPtr LLMClient_Construct(IntPtr llm) => LlamaLib.LLMClient_Construct_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMClient_Construct_Remote")] + public static extern IntPtr LLMClient_Construct_Remote_Static( + [MarshalAs(UnmanagedType.LPStr)] string url, + int port, + [MarshalAs(UnmanagedType.LPStr)] string apiKey = ""); + public IntPtr LLMClient_Construct_Remote(string url, int port, string apiKey = "") => LlamaLib.LLMClient_Construct_Remote_Static(url, port, apiKey); + + // LLMAgent functions + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Construct")] + public static extern IntPtr LLMAgent_Construct_Static(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string systemPrompt = "", + [MarshalAs(UnmanagedType.LPStr)] string userRole = "user", + [MarshalAs(UnmanagedType.LPStr)] string assistantRole = "assistant"); + public IntPtr LLMAgent_Construct(IntPtr llm, string systemPrompt = "", string userRole = "user", string assistantRole = "assistant") + => LLMAgent_Construct_Static(llm, systemPrompt, userRole, assistantRole); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Set_User_Role")] + public static extern void LLMAgent_Set_User_Role_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string userRole); + public void LLMAgent_Set_User_Role(IntPtr llm, string userRole) => LLMAgent_Set_User_Role_Static(llm, userRole); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_User_Role")] + public static extern IntPtr LLMAgent_Get_User_Role_Static(IntPtr llm); + public IntPtr LLMAgent_Get_User_Role(IntPtr llm) => LLMAgent_Get_User_Role_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Set_Assistant_Role")] + public static extern void LLMAgent_Set_Assistant_Role_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string assistantRole); + public void LLMAgent_Set_Assistant_Role(IntPtr llm, string assistantRole) => LLMAgent_Set_Assistant_Role_Static(llm, assistantRole); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_Assistant_Role")] + public static extern IntPtr LLMAgent_Get_Assistant_Role_Static(IntPtr llm); + public IntPtr LLMAgent_Get_Assistant_Role(IntPtr llm) => LLMAgent_Get_Assistant_Role_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Set_System_Prompt")] + public static extern void LLMAgent_Set_System_Prompt_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string systemPrompt); + public void LLMAgent_Set_System_Prompt(IntPtr llm, string systemPrompt) => LLMAgent_Set_System_Prompt_Static(llm, systemPrompt); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_System_Prompt")] + public static extern IntPtr LLMAgent_Get_System_Prompt_Static(IntPtr llm); + public IntPtr LLMAgent_Get_System_Prompt(IntPtr llm) => LLMAgent_Get_System_Prompt_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Set_Completion_Parameters")] + public static extern void LLM_Set_Completion_Parameters_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string parameters); + public void LLM_Set_Completion_Parameters(IntPtr llm, string parameters) => LLM_Set_Completion_Parameters_Static(llm, parameters); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Get_Completion_Parameters")] + public static extern IntPtr LLM_Get_Completion_Parameters_Static(IntPtr llm); + public IntPtr LLM_Get_Completion_Parameters(IntPtr llm) => LLM_Get_Completion_Parameters_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Set_Grammar")] + public static extern void LLM_Set_Grammar_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string grammar); + public void LLM_Set_Grammar(IntPtr llm, string grammar) => LLM_Set_Grammar_Static(llm, grammar); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLM_Get_Grammar")] + public static extern IntPtr LLM_Get_Grammar_Static(IntPtr llm); + public IntPtr LLM_Get_Grammar(IntPtr llm) => LLM_Get_Grammar_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Set_Slot")] + public static extern void LLMAgent_Set_Slot_Static(IntPtr llm, int slotId); + public void LLMAgent_Set_Slot(IntPtr llm, int slotId) => LLMAgent_Set_Slot_Static(llm, slotId); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_Slot")] + public static extern IntPtr LLMAgent_Get_Slot_Static(IntPtr llm); + public IntPtr LLMAgent_Get_Slot(IntPtr llm) => LLMAgent_Get_Slot_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Chat")] + public static extern IntPtr LLMAgent_Chat_Static(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string userPrompt, + [MarshalAs(UnmanagedType.I1)] bool addToHistory = true, + CharArrayCallback callback = null, + [MarshalAs(UnmanagedType.I1)] bool returnResponseJson = false); + public IntPtr LLMAgent_Chat(IntPtr llm, string userPrompt, bool addToHistory = true, CharArrayCallback callback = null, bool returnResponseJson = false) + => LLMAgent_Chat_Static(llm, userPrompt, addToHistory, callback, returnResponseJson); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Clear_History")] + public static extern void LLMAgent_Clear_History_Static(IntPtr llm); + public void LLMAgent_Clear_History(IntPtr llm) => LLMAgent_Clear_History_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_History")] + public static extern IntPtr LLMAgent_Get_History_Static(IntPtr llm); + public IntPtr LLMAgent_Get_History(IntPtr llm) => LLMAgent_Get_History_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Set_History")] + public static extern void LLMAgent_Set_History_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string historyJson); + public void LLMAgent_Set_History(IntPtr llm, string historyJson) => LLMAgent_Set_History_Static(llm, historyJson); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Add_Message")] + public static extern void LLMAgent_Add_Message_Static(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string role, + [MarshalAs(UnmanagedType.LPStr)] string content); + public void LLMAgent_Add_Message(IntPtr llm, string role, string content) => LLMAgent_Add_Message_Static(llm, role, content); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Remove_Last_Message")] + public static extern void LLMAgent_Remove_Last_Message_Static(IntPtr llm); + public void LLMAgent_Remove_Last_Message(IntPtr llm) => LLMAgent_Remove_Last_Message_Static(llm); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Save_History")] + public static extern void LLMAgent_Save_History_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string filepath); + public void LLMAgent_Save_History(IntPtr llm, string filepath) => LLMAgent_Save_History_Static(llm, filepath); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Load_History")] + public static extern void LLMAgent_Load_History_Static(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string filepath); + public void LLMAgent_Load_History(IntPtr llm, string filepath) => LLMAgent_Load_History_Static(llm, filepath); + + [DllImport(DllName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LLMAgent_Get_History_Size")] + public static extern int LLMAgent_Get_History_Size_Static(IntPtr llm); + public int LLMAgent_Get_History_Size(IntPtr llm) => LLMAgent_Get_History_Size_Static(llm); + +#else + // Desktop platform implementation with dynamic loading + private static List instances = new List(); + private static readonly object runtimeLock = new object(); + private static IntPtr runtimeLibraryHandle = IntPtr.Zero; + private IntPtr libraryHandle = IntPtr.Zero; + private static int debugLevelGlobal = 0; + private static CharArrayCallback loggingCallbackGlobal = null; + + private void LoadRuntimeLibrary() + { + lock (runtimeLock) + { + if (runtimeLibraryHandle == IntPtr.Zero) + { + runtimeLibraryHandle = LibraryLoader.LoadLibrary(GetRuntimeLibraryPath()); + Has_GPU_Layers = LibraryLoader.GetSymbolDelegate(runtimeLibraryHandle, "Has_GPU_Layers"); + Available_Architectures = LibraryLoader.GetSymbolDelegate(runtimeLibraryHandle, "Available_Architectures"); + } + } + } + + public LlamaLib(bool gpu = false) + { + LoadRuntimeLibrary(); + LoadLibraries(gpu); + lock (runtimeLock) + { + instances.Add(this); + LLM_Debug(debugLevelGlobal); + if (loggingCallbackGlobal != null) LLM_Logging_Callback(loggingCallbackGlobal); + } + } + + public static string GetPlatform() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return "linux-x64"; + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + if (RuntimeInformation.ProcessArchitecture == Architecture.X64) + return "osx-x64"; + else + return "osx-arm64"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return "win-x64"; + else throw new ArgumentException("Unknown platform " + RuntimeInformation.OSDescription); + } + + public virtual string FindLibrary(string libraryName) + { + string baseDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + + List lookupDirs = new List(); + lookupDirs.Add(Path.Combine(baseDir, "runtimes", GetPlatform(), "native")); + lookupDirs.Add(baseDir); + + foreach (string lookupDir in lookupDirs) + { + string libraryPath = Path.Combine(lookupDir, libraryName); + if (File.Exists(libraryPath)) return libraryPath; + } + + throw new InvalidOperationException($"Library {libraryName} not found!"); + } + + private string GetRuntimeLibraryPath() + { + string platform = GetPlatform(); + string libName; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + libName = "libllamalib_" + platform + "_runtime.so"; + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + libName = "libllamalib_" + platform + "_runtime.dylib"; + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + libName = "llamalib_" + platform + "_runtime.dll"; + else + throw new ArgumentException("Unknown platform " + RuntimeInformation.OSDescription); + return FindLibrary(libName); + } + + private void LoadLibraries(bool gpu) + { + string architecturesString = Marshal.PtrToStringAnsi(Available_Architectures(gpu)); + if (string.IsNullOrEmpty(architecturesString)) + { + throw new InvalidOperationException("No architectures available for the specified GPU setting."); + } + + string[] libraries = architecturesString.Split(','); + Exception lastException = null; + + foreach (string library in libraries) + { + try + { + string libraryPath = FindLibrary(library.Trim()); + if (debugLevelGlobal > 0) Console.WriteLine("Trying " + libraryPath); + libraryHandle = LibraryLoader.LoadLibrary(libraryPath); + LoadFunctionPointers(); + architecture = library.Trim(); + if (debugLevelGlobal > 0) Console.WriteLine("Successfully loaded: " + libraryPath); + return; + } + catch (Exception ex) + { + if (debugLevelGlobal > 0) Console.WriteLine($"Failed to load library {library}: {ex.Message}."); + lastException = ex; + continue; + } + } + + // If we get here, no library was successfully loaded + throw new InvalidOperationException($"Failed to load any library. Available libraries: {string.Join(", ", libraries)}", lastException); + } + + private void LoadFunctionPointers() + { + LLM_Debug = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Debug"); + LLM_Logging_Callback = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Logging_Callback"); + LLM_Logging_Stop = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Logging_Stop"); + LLM_Get_Template = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Get_Template"); + LLM_Set_Template = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Set_Template"); + LLM_Apply_Template = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Apply_Template"); + LLM_Tokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Tokenize"); + LLM_Detokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Detokenize"); + LLM_Embeddings = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embeddings"); + LLM_Completion = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Completion"); + LLM_Save_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Save_Slot"); + LLM_Load_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Load_Slot"); + LLM_Cancel = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Cancel"); + LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight"); + LLM_Lora_List = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_List"); + LLM_Delete = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Delete"); + LLM_Start = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Start"); + LLM_Started = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Started"); + LLM_Stop = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Stop"); + LLM_Start_Server = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Start_Server"); + LLM_Stop_Server = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Stop_Server"); + LLM_Join_Service = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Join_Service"); + LLM_Join_Server = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Join_Server"); + LLM_Set_SSL = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Set_SSL"); + LLM_Status_Code = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Status_Code"); + LLM_Status_Message = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Status_Message"); + LLM_Embedding_Size = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embedding_Size"); + LLMService_Construct = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMService_Construct"); + LLMService_From_Command = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMService_From_Command"); + LLMClient_Construct = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMClient_Construct"); + LLMClient_Construct_Remote = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMClient_Construct_Remote"); + LLMAgent_Construct = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Construct"); + LLMAgent_Set_User_Role = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Set_User_Role"); + LLMAgent_Get_User_Role = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_User_Role"); + LLMAgent_Set_Assistant_Role = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Set_Assistant_Role"); + LLMAgent_Get_Assistant_Role = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_Assistant_Role"); + LLMAgent_Set_System_Prompt = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Set_System_Prompt"); + LLMAgent_Get_System_Prompt = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_System_Prompt"); + LLMAgent_Set_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Set_Slot"); + LLMAgent_Get_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_Slot"); + LLM_Set_Completion_Parameters = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Set_Completion_Parameters"); + LLM_Get_Completion_Parameters = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Get_Completion_Parameters"); + LLM_Set_Grammar = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Set_Grammar"); + LLM_Get_Grammar = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Get_Grammar"); + LLMAgent_Chat = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Chat"); + LLMAgent_Clear_History = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Clear_History"); + LLMAgent_Get_History = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_History"); + LLMAgent_Set_History = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Set_History"); + LLMAgent_Add_Message = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Add_Message"); + LLMAgent_Remove_Last_Message = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Remove_Last_Message"); + LLMAgent_Save_History = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Save_History"); + LLMAgent_Load_History = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Load_History"); + LLMAgent_Get_History_Size = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLMAgent_Get_History_Size"); + } + + // Delegate definitions for desktop platforms + // Runtime lib + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr Available_Architectures_Delegate([MarshalAs(UnmanagedType.I1)] bool gpu); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate bool Has_GPU_Layers_Delegate([MarshalAs(UnmanagedType.LPStr)] string command); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Debug_Delegate(int debugLevel); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Logging_Callback_Delegate(CharArrayCallback callback); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Logging_Stop_Delegate(); + + // Main lib + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Get_Template_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Set_Template_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string template); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Apply_Template_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string messages_as_json); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Tokenize_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Detokenize_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string tokens_as_json); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Embeddings_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Completion_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string query, CharArrayCallback callback, int id_slot = -1, bool return_response_json = false); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Save_Slot_Delegate(IntPtr llm, int id_slot, [MarshalAs(UnmanagedType.LPStr)] string filepath); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Load_Slot_Delegate(IntPtr llm, int id_slot, [MarshalAs(UnmanagedType.LPStr)] string filepath); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Cancel_Delegate(IntPtr llm, int idSlot); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate bool LLM_Lora_Weight_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string loras_as_json); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Lora_List_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Delete_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Start_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate bool LLM_Started_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Stop_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Start_Server_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string host = "0.0.0.0", int port = -1, [MarshalAs(UnmanagedType.LPStr)] string apiKey = ""); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Stop_Server_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Join_Service_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Join_Server_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Set_SSL_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string sslCert, [MarshalAs(UnmanagedType.LPStr)] string sslKey); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate int LLM_Status_Code_Delegate(); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Status_Message_Delegate(); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate int LLM_Embedding_Size_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMService_Construct_Delegate( + [MarshalAs(UnmanagedType.LPStr)] string modelPath, + int numSlots = 1, + int numThreads = -1, + int numGpuLayers = 0, + [MarshalAs(UnmanagedType.I1)] bool flashAttention = false, + int contextSize = 4096, + int batchSize = 2048, + [MarshalAs(UnmanagedType.I1)] bool embeddingOnly = false, + int loraCount = 0, + IntPtr loraPaths = default); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMService_From_Command_Delegate([MarshalAs(UnmanagedType.LPStr)] string paramsString); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMClient_Construct_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMClient_Construct_Remote_Delegate([MarshalAs(UnmanagedType.LPStr)] string url, int port, [MarshalAs(UnmanagedType.LPStr)] string apiKey = ""); + + // LLMAgent functions + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Construct_Delegate(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string systemPrompt = "", + [MarshalAs(UnmanagedType.LPStr)] string userRole = "user", + [MarshalAs(UnmanagedType.LPStr)] string assistantRole = "assistant"); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Set_User_Role_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string userRole); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Get_User_Role_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Set_Assistant_Role_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string assistantRole); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Get_Assistant_Role_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Set_System_Prompt_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string systemPrompt); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Get_System_Prompt_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Set_Grammar_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string grammar); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Get_Grammar_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLM_Set_Completion_Parameters_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string parameters); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLM_Get_Completion_Parameters_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Set_Slot_Delegate(IntPtr llm, int slotId); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate int LLMAgent_Get_Slot_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Chat_Delegate(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string userPrompt, + [MarshalAs(UnmanagedType.I1)] bool addToHistory = true, + CharArrayCallback callback = null, + [MarshalAs(UnmanagedType.I1)] bool returnResponseJson = false); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Clear_History_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate IntPtr LLMAgent_Get_History_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Set_History_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string historyJson); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Add_Message_Delegate(IntPtr llm, + [MarshalAs(UnmanagedType.LPStr)] string role, + [MarshalAs(UnmanagedType.LPStr)] string content); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Remove_Last_Message_Delegate(IntPtr llm); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Save_History_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string filepath); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void LLMAgent_Load_History_Delegate(IntPtr llm, [MarshalAs(UnmanagedType.LPStr)] string filepath); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate int LLMAgent_Get_History_Size_Delegate(IntPtr llm); + + + // Function pointers for desktop platforms + // Runtime lib + public static Available_Architectures_Delegate Available_Architectures; + public static Has_GPU_Layers_Delegate Has_GPU_Layers; + + // Main lib + public LLM_Debug_Delegate LLM_Debug; + public LLM_Logging_Callback_Delegate LLM_Logging_Callback; + public LLM_Logging_Stop_Delegate LLM_Logging_Stop; + public LLM_Get_Template_Delegate LLM_Get_Template; + public LLM_Set_Template_Delegate LLM_Set_Template; + public LLM_Apply_Template_Delegate LLM_Apply_Template; + public LLM_Tokenize_Delegate LLM_Tokenize; + public LLM_Detokenize_Delegate LLM_Detokenize; + public LLM_Embeddings_Delegate LLM_Embeddings; + public LLM_Completion_Delegate LLM_Completion; + public LLM_Save_Slot_Delegate LLM_Save_Slot; + public LLM_Load_Slot_Delegate LLM_Load_Slot; + public LLM_Cancel_Delegate LLM_Cancel; + public LLM_Lora_Weight_Delegate LLM_Lora_Weight; + public LLM_Lora_List_Delegate LLM_Lora_List; + public LLM_Delete_Delegate LLM_Delete; + public LLM_Start_Delegate LLM_Start; + public LLM_Started_Delegate LLM_Started; + public LLM_Stop_Delegate LLM_Stop; + public LLM_Start_Server_Delegate LLM_Start_Server; + public LLM_Stop_Server_Delegate LLM_Stop_Server; + public LLM_Join_Service_Delegate LLM_Join_Service; + public LLM_Join_Server_Delegate LLM_Join_Server; + public LLM_Set_SSL_Delegate LLM_Set_SSL; + public LLM_Status_Code_Delegate LLM_Status_Code; + public LLM_Status_Message_Delegate LLM_Status_Message; + public LLM_Embedding_Size_Delegate LLM_Embedding_Size; + public LLMService_Construct_Delegate LLMService_Construct; + public LLMService_From_Command_Delegate LLMService_From_Command; + public LLMClient_Construct_Delegate LLMClient_Construct; + public LLMClient_Construct_Remote_Delegate LLMClient_Construct_Remote; + public LLMAgent_Construct_Delegate LLMAgent_Construct; + public LLMAgent_Set_User_Role_Delegate LLMAgent_Set_User_Role; + public LLMAgent_Get_User_Role_Delegate LLMAgent_Get_User_Role; + public LLMAgent_Set_Assistant_Role_Delegate LLMAgent_Set_Assistant_Role; + public LLMAgent_Get_Assistant_Role_Delegate LLMAgent_Get_Assistant_Role; + public LLMAgent_Set_System_Prompt_Delegate LLMAgent_Set_System_Prompt; + public LLMAgent_Get_System_Prompt_Delegate LLMAgent_Get_System_Prompt; + public LLMAgent_Set_Slot_Delegate LLMAgent_Set_Slot; + public LLMAgent_Get_Slot_Delegate LLMAgent_Get_Slot; + public LLM_Set_Completion_Parameters_Delegate LLM_Set_Completion_Parameters; + public LLM_Get_Completion_Parameters_Delegate LLM_Get_Completion_Parameters; + public LLM_Set_Grammar_Delegate LLM_Set_Grammar; + public LLM_Get_Grammar_Delegate LLM_Get_Grammar; + public LLMAgent_Chat_Delegate LLMAgent_Chat; + public LLMAgent_Clear_History_Delegate LLMAgent_Clear_History; + public LLMAgent_Get_History_Delegate LLMAgent_Get_History; + public LLMAgent_Set_History_Delegate LLMAgent_Set_History; + public LLMAgent_Add_Message_Delegate LLMAgent_Add_Message; + public LLMAgent_Remove_Last_Message_Delegate LLMAgent_Remove_Last_Message; + public LLMAgent_Save_History_Delegate LLMAgent_Save_History; + public LLMAgent_Load_History_Delegate LLMAgent_Load_History; + public LLMAgent_Get_History_Size_Delegate LLMAgent_Get_History_Size; + + public static void Debug(int debugLevel) + { + debugLevelGlobal = debugLevel; + foreach (LlamaLib instance in instances) + { + instance.LLM_Debug(debugLevel); + } + } + + public static void LoggingCallback(CharArrayCallback callback) + { + loggingCallbackGlobal = callback; + foreach (LlamaLib instance in instances) + { + instance.LLM_Logging_Callback(callback); + } + } + + public static void LoggingStop() + { + LoggingCallback(null); + } + + public void Dispose() + { + LibraryLoader.FreeLibrary(libraryHandle); + libraryHandle = IntPtr.Zero; + + lock (runtimeLock) + { + instances.Remove(this); + if (instances.Count == 0) + { + LibraryLoader.FreeLibrary(runtimeLibraryHandle); + runtimeLibraryHandle = IntPtr.Zero; + } + } + } + + ~LlamaLib() + { + Dispose(); + } + +#endif + } +} diff --git a/Runtime/LlamaLib/LlamaLib.cs.meta b/Runtime/LlamaLib/LlamaLib.cs.meta new file mode 100644 index 00000000..a6e2a0af --- /dev/null +++ b/Runtime/LlamaLib/LlamaLib.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: fe209b11ce081b521a4c1bffd25c4e32 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/RAG/Chunking.cs b/Runtime/RAG/Chunking.cs index 4040e663..094b4b2b 100644 --- a/Runtime/RAG/Chunking.cs +++ b/Runtime/RAG/Chunking.cs @@ -37,7 +37,7 @@ public void ReturnChunks(bool returnChunks) /// /// phrase /// List of start/end indices of the split chunks - public abstract Task> Split(string input); + public abstract List<(int, int)> Split(string input); /// /// Retrieves the phrase with the specific id @@ -60,15 +60,15 @@ public override string Get(int key) /// input phrase /// data group to add it to /// phrase id - public override async Task Add(string inputString, string group = "") + public override int Add(string inputString, string group = "") { int key = nextKey++; // sentence -> phrase List sentenceIds = new List(); - foreach ((int startIndex, int endIndex) in await Split(inputString)) + foreach ((int startIndex, int endIndex) in Split(inputString)) { string sentenceText = inputString.Substring(startIndex, endIndex - startIndex + 1); - int sentenceId = await search.Add(sentenceText, group); + int sentenceId = search.Add(sentenceText, group); sentenceIds.Add(sentenceId); sentenceToPhrase[sentenceId] = key; @@ -166,9 +166,9 @@ public override int Count(string group) /// search query /// data group to search in /// incremental search key - public override async Task IncrementalSearch(string queryString, string group = "") + public override int IncrementalSearch(string queryString, string group = "") { - return await search.IncrementalSearch(queryString, group); + return search.IncrementalSearch(queryString, group); } /// diff --git a/Runtime/RAG/RAG.cs b/Runtime/RAG/RAG.cs index a1b26183..d83e57ca 100644 --- a/Runtime/RAG/RAG.cs +++ b/Runtime/RAG/RAG.cs @@ -117,13 +117,13 @@ public virtual void OnValidate() #endif public override string Get(int key) { return GetSearcher().Get(key); } - public override async Task Add(string inputString, string group = "") { return await GetSearcher().Add(inputString, group); } + public override int Add(string inputString, string group = "") { return GetSearcher().Add(inputString, group); } public override int Remove(string inputString, string group = "") { return GetSearcher().Remove(inputString, group); } public override void Remove(int key) { GetSearcher().Remove(key); } public override int Count() { return GetSearcher().Count(); } public override int Count(string group) { return GetSearcher().Count(group); } public override void Clear() { GetSearcher().Clear(); } - public override async Task IncrementalSearch(string queryString, string group = "") { return await GetSearcher().IncrementalSearch(queryString, group);} + public override int IncrementalSearch(string queryString, string group = "") { return GetSearcher().IncrementalSearch(queryString, group);} public override (string[], float[], bool) IncrementalFetch(int fetchKey, int k) { return GetSearcher().IncrementalFetch(fetchKey, k);} public override (int[], float[], bool) IncrementalFetchKeys(int fetchKey, int k) { return GetSearcher().IncrementalFetchKeys(fetchKey, k);} public override void IncrementalSearchComplete(int fetchKey) { GetSearcher().IncrementalSearchComplete(fetchKey);} diff --git a/Runtime/RAG/Search.cs b/Runtime/RAG/Search.cs index 249beafb..84e4938f 100644 --- a/Runtime/RAG/Search.cs +++ b/Runtime/RAG/Search.cs @@ -33,7 +33,7 @@ public abstract class Searchable : MonoBehaviour /// input phrase /// data group to add it to /// phrase id - public abstract Task Add(string inputString, string group = ""); + public abstract int Add(string inputString, string group = ""); /// /// Removes a phrase from the search. @@ -73,7 +73,7 @@ public abstract class Searchable : MonoBehaviour /// search query /// data group to search in /// incremental search key - public abstract Task IncrementalSearch(string queryString, string group = ""); + public abstract int IncrementalSearch(string queryString, string group = ""); /// /// Retrieves the most similar search results in batches (incremental search). @@ -112,9 +112,9 @@ public abstract class Searchable : MonoBehaviour /// `bool` indicating if the search is exhausted. /// /// - public async Task<(string[], float[])> Search(string queryString, int k, string group = "") + public (string[], float[]) Search(string queryString, int k, string group = "") { - int fetchKey = await IncrementalSearch(queryString, group); + int fetchKey = IncrementalSearch(queryString, group); (string[] phrases, float[] distances, bool completed) = IncrementalFetch(fetchKey, k); if (!completed) IncrementalSearchComplete(fetchKey); return (phrases, distances); @@ -290,11 +290,11 @@ public void SetLLM(LLM llm) /// Array of distances for each result (`float[]`). /// /// - public async Task<(string[], float[])> SearchFromList(string query, string[] searchList) + public (string[], float[]) SearchFromList(string query, string[] searchList) { - float[] embedding = await Encode(query); + float[] embedding = Encode(query); float[][] embeddingsList = new float[searchList.Length][]; - for (int i = 0; i < searchList.Length; i++) embeddingsList[i] = await Encode(searchList[i]); + for (int i = 0; i < searchList.Length; i++) embeddingsList[i] = Encode(searchList[i]); float[] unsortedDistances = InverseDotProduct(embedding, embeddingsList); List<(string, float)> sortedLists = searchList.Zip(unsortedDistances, (first, second) => (first, second)) @@ -339,19 +339,19 @@ public static float[] InverseDotProduct(float[] vector1, float[][] vector2) return results; } - public virtual async Task Encode(string inputString) + public virtual float[] Encode(string inputString) { - return (await llmEmbedder.Embeddings(inputString)).ToArray(); + return llmEmbedder.Embeddings(inputString).ToArray(); } - public virtual async Task> Tokenize(string query, Callback> callback = null) + public virtual List Tokenize(string query, Callback> callback = null) { - return await llmEmbedder.Tokenize(query, callback); + return llmEmbedder.Tokenize(query, callback); } - public async Task Detokenize(List tokens, Callback callback = null) + public virtual string Detokenize(List tokens, Callback callback = null) { - return await llmEmbedder.Detokenize(tokens, callback); + return llmEmbedder.Detokenize(tokens, callback); } public override string Get(int key) @@ -360,10 +360,10 @@ public override string Get(int key) return null; } - public override async Task Add(string inputString, string group = "") + public override int Add(string inputString, string group = "") { int key = nextKey++; - AddInternal(key, await Encode(inputString)); + AddInternal(key, Encode(inputString)); data[key] = inputString; if (!dataSplits.ContainsKey(group)) dataSplits[group] = new List(){key}; @@ -421,9 +421,9 @@ public override int Count(string group) return dataSplit.Count; } - public override async Task IncrementalSearch(string queryString, string group = "") + public override int IncrementalSearch(string queryString, string group = "") { - return IncrementalSearch(await Encode(queryString), group); + return IncrementalSearch(Encode(queryString), group); } public override void Save(ZipArchive archive) diff --git a/Runtime/RAG/SentenceSplitter.cs b/Runtime/RAG/SentenceSplitter.cs index 809300f0..3d46f938 100644 --- a/Runtime/RAG/SentenceSplitter.cs +++ b/Runtime/RAG/SentenceSplitter.cs @@ -25,30 +25,28 @@ public class SentenceSplitter : Chunking /// /// phrase /// List of start/end indices of the split chunks - public override async Task> Split(string input) + public override List<(int, int)> Split(string input) { List<(int, int)> indices = new List<(int, int)>(); - await Task.Run(() => { - int startIndex = 0; - bool seenChar = false; - for (int i = 0; i < input.Length; i++) + int startIndex = 0; + bool seenChar = false; + for (int i = 0; i < input.Length; i++) + { + bool isDelimiter = delimiters.Contains(input[i]); + if (isDelimiter) { - bool isDelimiter = delimiters.Contains(input[i]); - if (isDelimiter) - { - while ((i < input.Length - 1) && (delimiters.Contains(input[i + 1]) || char.IsWhiteSpace(input[i + 1]))) i++; - } - else - { - if (!seenChar) seenChar = !char.IsWhiteSpace(input[i]); - } - if ((i == input.Length - 1) || (isDelimiter && seenChar)) - { - indices.Add((startIndex, i)); - startIndex = i + 1; - } + while ((i < input.Length - 1) && (delimiters.Contains(input[i + 1]) || char.IsWhiteSpace(input[i + 1]))) i++; } - }); + else + { + if (!seenChar) seenChar = !char.IsWhiteSpace(input[i]); + } + if ((i == input.Length - 1) || (isDelimiter && seenChar)) + { + indices.Add((startIndex, i)); + startIndex = i + 1; + } + } return indices; } } diff --git a/Runtime/RAG/TokenSplitter.cs b/Runtime/RAG/TokenSplitter.cs index cb57313b..7abf563b 100644 --- a/Runtime/RAG/TokenSplitter.cs +++ b/Runtime/RAG/TokenSplitter.cs @@ -47,17 +47,17 @@ protected int DetermineEndIndex(string input, string detokenised, int startIndex /// /// phrase /// List of start/end indices of the split chunks - public override async Task> Split(string input) + public override List<(int, int)> Split(string input) { List<(int, int)> indices = new List<(int, int)>(); - List tokens = await search.Tokenize(input); + List tokens = search.Tokenize(input); if (tokens.Count == 0) return indices; int startIndex = 0; for (int i = 0; i < tokens.Count; i += numTokens) { int batchTokens = Math.Min(tokens.Count, i + numTokens) - i; - string detokenised = await search.Detokenize(tokens.GetRange(i, batchTokens)); + string detokenised = search.Detokenize(tokens.GetRange(i, batchTokens)); int endIndex = DetermineEndIndex(input, detokenised, startIndex); indices.Add((startIndex, endIndex)); startIndex = endIndex + 1; diff --git a/Runtime/RAG/WordSplitter.cs b/Runtime/RAG/WordSplitter.cs index 251bdff5..69fbb6ca 100644 --- a/Runtime/RAG/WordSplitter.cs +++ b/Runtime/RAG/WordSplitter.cs @@ -23,7 +23,7 @@ public class WordSplitter : Chunking /// /// phrase /// List of start/end indices of the split chunks - public override async Task> Split(string input) + public override List<(int, int)> Split(string input) { bool IsBoundary(char c) { @@ -31,27 +31,25 @@ bool IsBoundary(char c) } List<(int, int)> indices = new List<(int, int)>(); - await Task.Run(() => { - List<(int, int)> wordIndices = new List<(int, int)>(); - int startIndex = 0; - int endIndex; - for (int i = 0; i < input.Length; i++) + List<(int, int)> wordIndices = new List<(int, int)>(); + int startIndex = 0; + int endIndex; + for (int i = 0; i < input.Length; i++) + { + if (i == input.Length - 1 || IsBoundary(input[i])) { - if (i == input.Length - 1 || IsBoundary(input[i])) - { - while (i < input.Length - 1 && IsBoundary(input[i + 1])) i++; - endIndex = i; - wordIndices.Add((startIndex, endIndex)); - startIndex = i + 1; - } + while (i < input.Length - 1 && IsBoundary(input[i + 1])) i++; + endIndex = i; + wordIndices.Add((startIndex, endIndex)); + startIndex = i + 1; } + } - for (int i = 0; i < wordIndices.Count; i += numWords) - { - int iTo = Math.Min(wordIndices.Count - 1, i + numWords - 1); - indices.Add((wordIndices[i].Item1, wordIndices[iTo].Item2)); - } - }); + for (int i = 0; i < wordIndices.Count; i += numWords) + { + int iTo = Math.Min(wordIndices.Count - 1, i + numWords - 1); + indices.Add((wordIndices[i].Item1, wordIndices[iTo].Item2)); + } return indices; } } diff --git a/Samples~/ChatBot/ChatBot.cs b/Samples~/ChatBot/ChatBot.cs index ac2f893d..4897018f 100644 --- a/Samples~/ChatBot/ChatBot.cs +++ b/Samples~/ChatBot/ChatBot.cs @@ -15,7 +15,7 @@ public class ChatBot : MonoBehaviour public Font font; public int fontSize = 16; public int bubbleWidth = 600; - public LLMCharacter llmCharacter; + public LLMAgent llmAgent; public float textPadding = 10f; public float bubbleSpacing = 10f; public Sprite sprite; @@ -55,12 +55,12 @@ void Start() inputBubble.setInteractable(false); stopButton.gameObject.SetActive(true); ShowLoadedMessages(); - _ = llmCharacter.Warmup(WarmUpCallback); + _ = llmAgent.Warmup(WarmUpCallback); } Bubble AddBubble(string message, bool isPlayerMessage) { - Bubble bubble = new Bubble(chatContainer, isPlayerMessage? playerUI: aiUI, isPlayerMessage? "PlayerBubble": "AIBubble", message); + Bubble bubble = new Bubble(chatContainer, isPlayerMessage ? playerUI : aiUI, isPlayerMessage ? "PlayerBubble" : "AIBubble", message); chatBubbles.Add(bubble); bubble.OnResize(UpdateBubblePositions); return bubble; @@ -68,7 +68,7 @@ Bubble AddBubble(string message, bool isPlayerMessage) void ShowLoadedMessages() { - for (int i=1; i LoadQuestionAnswers(string questionAnswersText) @@ -82,7 +82,7 @@ public async Task CreateEmbeddings() PlayerText.text += $"Creating Embeddings for {botName} (only once)...\n"; List questions = botQuestionAnswers.Keys.ToList(); stopwatch.Start(); - foreach (string question in questions) await rag.Add(question, botName); + foreach (string question in questions) rag.Add(question, botName); stopwatch.Stop(); Debug.Log($"embedded {rag.Count()} phrases in {stopwatch.Elapsed.TotalMilliseconds / 1000f} secs"); } @@ -95,20 +95,20 @@ public async Task CreateEmbeddings() } } - public async Task> Retrieval(string question) + public List Retrieval(string question) { // find similar questions for the current bot using the RAG - (string[] similarQuestions, _) = await rag.Search(question, numRAGResults, currentBotName); + (string[] similarQuestions, _) = rag.Search(question, numRAGResults, currentBotName); // get the answers of the similar questions List similarAnswers = new List(); foreach (string similarQuestion in similarQuestions) similarAnswers.Add(botQuestionAnswers[currentBotName][similarQuestion]); return similarAnswers; } - public async Task ConstructPrompt(string question) + public string ConstructPrompt(string question) { // get similar answers from the RAG - List similarAnswers = await Retrieval(question); + List similarAnswers = Retrieval(question); // create the prompt using the user question and the similar answers string answers = ""; foreach (string similarAnswer in similarAnswers) answers += $"\n- {similarAnswer}"; @@ -118,12 +118,12 @@ public async Task ConstructPrompt(string question) return prompt; } - protected async override void OnInputFieldSubmit(string question) + protected override void OnInputFieldSubmit(string question) { PlayerText.interactable = false; SetAIText("..."); - string prompt = await ConstructPrompt(question); - _ = llmCharacter.Chat(prompt, SetAIText, AIReplyComplete); + string prompt = ConstructPrompt(question); + _ = llmAgent.ChatAsync(prompt, SetAIText, AIReplyComplete); } protected override void DropdownChange(int selection) @@ -134,8 +134,8 @@ protected override void DropdownChange(int selection) botImages[currentBotName].gameObject.SetActive(true); Debug.Log($"{currentBotName}: {rag.Count(currentBotName)} phrases available"); - // set the LLMCharacter name - llmCharacter.AIName = currentBotName; + // set the LLMAgent name + llmAgent.assistantRole = currentBotName; } void SetAIText(string text) @@ -152,7 +152,7 @@ void AIReplyComplete() public void CancelRequests() { - llmCharacter.CancelRequests(); + llmAgent.CancelRequests(); AIReplyComplete(); } @@ -162,11 +162,11 @@ public void ExitGame() Application.Quit(); } - void CheckLLM(LLMCaller llmCaller, bool debug) + void CheckLLM(LLMClient llmClient, bool debug) { - if (!llmCaller.remote && llmCaller.llm != null && llmCaller.llm.model == "") + if (!llmClient.remote && llmClient.llm != null && llmClient.llm.model == "") { - string error = $"Please select a llm model in the {llmCaller.llm.gameObject.name} GameObject!"; + string error = $"Please select a llm model in the {llmClient.llm.gameObject.name} GameObject!"; if (debug) Debug.LogWarning(error); else throw new System.Exception(error); } @@ -175,7 +175,7 @@ void CheckLLM(LLMCaller llmCaller, bool debug) void CheckLLMs(bool debug) { CheckLLM(rag.search.llmEmbedder, debug); - CheckLLM(llmCharacter, debug); + CheckLLM(llmAgent, debug); } bool onValidateWarning = true; diff --git a/Samples~/MobileDemo/MobileDemo.cs b/Samples~/MobileDemo/MobileDemo.cs index b7aa0e36..ccf74529 100644 --- a/Samples~/MobileDemo/MobileDemo.cs +++ b/Samples~/MobileDemo/MobileDemo.cs @@ -7,7 +7,7 @@ namespace LLMUnitySamples { public class MobileDemo : MonoBehaviour { - public LLMCharacter llmCharacter; + public LLMAgent llmAgent; public GameObject ChatPanel; public InputField playerText; @@ -45,7 +45,7 @@ async Task DownloadThenWarmup() async Task WarmUp() { AIText.text += $"Warming up the model..."; - await llmCharacter.Warmup(); + await llmAgent.Warmup(); AIText.text = ""; AIReplyComplete(); } @@ -60,7 +60,7 @@ void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - _ = llmCharacter.Chat(message, SetAIText, AIReplyComplete); + _ = llmAgent.ChatAsync(message, SetAIText, AIReplyComplete); } public void SetAIText(string text) @@ -77,7 +77,7 @@ public void AIReplyComplete() public void CancelRequests() { - llmCharacter.CancelRequests(); + llmAgent.CancelRequests(); AIReplyComplete(); } @@ -91,14 +91,14 @@ public void ExitGame() bool onValidateInfo = true; void OnValidate() { - if (onValidateWarning && !llmCharacter.remote && llmCharacter.llm != null && llmCharacter.llm.model == "") + if (onValidateWarning && !llmAgent.remote && llmAgent.llm != null && llmAgent.llm.model == "") { - Debug.LogWarning($"Please select a model in the {llmCharacter.llm.gameObject.name} GameObject!"); + Debug.LogWarning($"Please select a model in the {llmAgent.llm.gameObject.name} GameObject!"); onValidateWarning = false; } if (onValidateInfo) { - Debug.Log($"Select 'Download On Start' in the {llmCharacter.llm.gameObject.name} GameObject to download the models when the app starts."); + Debug.Log($"Select 'Download On Start' in the {llmAgent.llm.gameObject.name} GameObject to download the models when the app starts."); onValidateInfo = false; } } diff --git a/Samples~/MultipleCharacters/MultipleCharacters.cs b/Samples~/MultipleCharacters/MultipleCharacters.cs index 82067e2c..39a426ad 100644 --- a/Samples~/MultipleCharacters/MultipleCharacters.cs +++ b/Samples~/MultipleCharacters/MultipleCharacters.cs @@ -9,13 +9,13 @@ public class MultipleCharactersInteraction { InputField playerText; Text AIText; - LLMCharacter llmCharacter; + LLMAgent llmAgent; - public MultipleCharactersInteraction(InputField playerText, Text AIText, LLMCharacter llmCharacter) + public MultipleCharactersInteraction(InputField playerText, Text AIText, LLMAgent llmAgent) { this.playerText = playerText; this.AIText = AIText; - this.llmCharacter = llmCharacter; + this.llmAgent = llmAgent; } public void Start() @@ -28,7 +28,7 @@ public void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - _ = llmCharacter.Chat(message, SetAIText, AIReplyComplete); + _ = llmAgent.ChatAsync(message, SetAIText, AIReplyComplete); } public void SetAIText(string text) @@ -46,12 +46,12 @@ public void AIReplyComplete() public class MultipleCharacters : MonoBehaviour { - public LLMCharacter llmCharacter1; + public LLMAgent llmCharacter1; public InputField playerText1; public Text AIText1; MultipleCharactersInteraction interaction1; - public LLMCharacter llmCharacter2; + public LLMAgent llmCharacter2; public InputField playerText2; public Text AIText2; MultipleCharactersInteraction interaction2; diff --git a/Samples~/RAG/RAGAndLLM_Sample.cs b/Samples~/RAG/RAGAndLLM_Sample.cs index 818324cd..4265d3cf 100644 --- a/Samples~/RAG/RAGAndLLM_Sample.cs +++ b/Samples~/RAG/RAGAndLLM_Sample.cs @@ -5,14 +5,14 @@ namespace LLMUnitySamples { public class RAGAndLLMSample : RAGSample { - public LLMCharacter llmCharacter; + public LLMAgent llmAgent; public Toggle ParaphraseWithLLM; - protected async override void onInputFieldSubmit(string message) + protected override void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - (string[] similarPhrases, float[] distances) = await rag.Search(message, 1); + (string[] similarPhrases, float[] distances) = rag.Search(message, 1); string similarPhrase = similarPhrases[0]; if (!ParaphraseWithLLM.isOn) { @@ -21,20 +21,20 @@ protected async override void onInputFieldSubmit(string message) } else { - _ = llmCharacter.Chat("Paraphrase the following phrase: " + similarPhrase, SetAIText, AIReplyComplete); + _ = llmAgent.ChatAsync("Paraphrase the following phrase: " + similarPhrase, SetAIText, AIReplyComplete); } } public void CancelRequests() { - llmCharacter.CancelRequests(); + llmAgent.CancelRequests(); AIReplyComplete(); } protected override void CheckLLMs(bool debug) { base.CheckLLMs(debug); - CheckLLM(llmCharacter, debug); + CheckLLM(llmAgent, debug); } } } diff --git a/Samples~/RAG/RAG_Sample.cs b/Samples~/RAG/RAG_Sample.cs index 0c1455c7..3b668e9d 100644 --- a/Samples~/RAG/RAG_Sample.cs +++ b/Samples~/RAG/RAG_Sample.cs @@ -44,7 +44,7 @@ public async Task CreateEmbeddings() playerText.text += $"Creating Embeddings (only once)...\n"; Stopwatch stopwatch = new Stopwatch(); stopwatch.Start(); - foreach (string phrase in phrases) await rag.Add(phrase); + foreach (string phrase in phrases) rag.Add(phrase); stopwatch.Stop(); Debug.Log($"embedded {rag.Count()} phrases in {stopwatch.Elapsed.TotalMilliseconds / 1000f} secs"); // store the embeddings @@ -56,11 +56,11 @@ public async Task CreateEmbeddings() } } - protected async virtual void onInputFieldSubmit(string message) + protected virtual void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - (string[] similarPhrases, float[] distances) = await rag.Search(message, 1); + (string[] similarPhrases, float[] distances) = rag.Search(message, 1); AIText.text = similarPhrases[0]; } @@ -82,11 +82,11 @@ public void ExitGame() Application.Quit(); } - protected void CheckLLM(LLMCaller llmCaller, bool debug) + protected void CheckLLM(LLMClient llmClient, bool debug) { - if (!llmCaller.remote && llmCaller.llm != null && llmCaller.llm.model == "") + if (!llmClient.remote && llmClient.llm != null && llmClient.llm.model == "") { - string error = $"Please select a llm model in the {llmCaller.llm.gameObject.name} GameObject!"; + string error = $"Please select a llm model in the {llmClient.llm.gameObject.name} GameObject!"; if (debug) Debug.LogWarning(error); else throw new System.Exception(error); } diff --git a/Samples~/SimpleInteraction/SimpleInteraction.cs b/Samples~/SimpleInteraction/SimpleInteraction.cs index 9c088c3b..060e9c41 100644 --- a/Samples~/SimpleInteraction/SimpleInteraction.cs +++ b/Samples~/SimpleInteraction/SimpleInteraction.cs @@ -6,7 +6,7 @@ namespace LLMUnitySamples { public class SimpleInteraction : MonoBehaviour { - public LLMCharacter llmCharacter; + public LLMAgent llmAgent; public InputField playerText; public Text AIText; @@ -20,7 +20,7 @@ void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - _ = llmCharacter.Chat(message, SetAIText, AIReplyComplete); + _ = llmAgent.ChatAsync(message, SetAIText, AIReplyComplete); } public void SetAIText(string text) @@ -37,7 +37,7 @@ public void AIReplyComplete() public void CancelRequests() { - llmCharacter.CancelRequests(); + llmAgent.CancelRequests(); AIReplyComplete(); } @@ -50,9 +50,9 @@ public void ExitGame() bool onValidateWarning = true; void OnValidate() { - if (onValidateWarning && !llmCharacter.remote && llmCharacter.llm != null && llmCharacter.llm.model == "") + if (onValidateWarning && !llmAgent.remote && llmAgent.llm != null && llmAgent.llm.model == "") { - Debug.LogWarning($"Please select a model in the {llmCharacter.llm.gameObject.name} GameObject!"); + Debug.LogWarning($"Please select a model in the {llmAgent.llm.gameObject.name} GameObject!"); onValidateWarning = false; } } diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 4f3210ef..90b70423 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -1,605 +1,573 @@ -using NUnit.Framework; -using LLMUnity; -using UnityEngine; -using System.Threading.Tasks; -using System.Collections.Generic; -using System; -using System.Collections; -using System.IO; -using System.Linq; -using System.Threading; -using UnityEngine.TestTools; -using UnityEditor; -using UnityEditor.TestTools.TestRunner.Api; - -namespace LLMUnityTests -{ - [InitializeOnLoad] - public static class TestRunListener - { - static TestRunListener() - { - var api = ScriptableObject.CreateInstance(); - api.RegisterCallbacks(new TestRunCallbacks()); - } - } - - public class TestRunCallbacks : ICallbacks - { - public void RunStarted(ITestAdaptor testsToRun) {} - - public void RunFinished(ITestResultAdaptor result) - { - LLMUnitySetup.FullLlamaLib = false; - } - - public void TestStarted(ITestAdaptor test) - { - LLMUnitySetup.FullLlamaLib = test.FullName.Contains("CUDA_full"); - } - - public void TestFinished(ITestResultAdaptor result) - { - LLMUnitySetup.FullLlamaLib = false; - } - } - - public class TestLLMLoraAssignment - { - [Test] - public void TestLoras() - { - GameObject gameObject = new GameObject(); - gameObject.SetActive(false); - LLM llm = gameObject.AddComponent(); - - string lora1 = LLMUnitySetup.GetFullPath("lala"); - string lora2Rel = "test/lala"; - string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel); - LLMUnitySetup.CreateEmptyFile(lora1); - Directory.CreateDirectory(Path.GetDirectoryName(lora2)); - LLMUnitySetup.CreateEmptyFile(lora2); - - llm.AddLora(lora1); - llm.AddLora(lora2); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "1,1"); - - llm.RemoveLoras(); - Assert.AreEqual(llm.lora, ""); - Assert.AreEqual(llm.loraWeights, ""); - - llm.AddLora(lora1, 0.8f); - llm.AddLora(lora2Rel, 0.9f); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "0.8,0.9"); - - llm.SetLoraWeight(lora2Rel, 0.7f); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "0.8,0.7"); - - llm.RemoveLora(lora2Rel); - Assert.AreEqual(llm.lora, lora1); - Assert.AreEqual(llm.loraWeights, "0.8"); - - llm.AddLora(lora2Rel); - llm.SetLoraWeight(lora2Rel, 0.5f); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "0.8,0.5"); - - llm.SetLoraWeight(lora2, 0.1f); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "0.8,0.1"); - - Dictionary loraToWeight = new Dictionary(); - loraToWeight[lora1] = 0; - loraToWeight[lora2] = 0.2f; - llm.SetLoraWeights(loraToWeight); - Assert.AreEqual(llm.lora, lora1 + "," + lora2); - Assert.AreEqual(llm.loraWeights, "0,0.2"); - - File.Delete(lora1); - File.Delete(lora2); - } - } - - public class TestLLM - { - protected string modelNameLLManager; - - protected GameObject gameObject; - protected LLM llm; - protected LLMCharacter llmCharacter; - protected Exception error = null; - protected string prompt; - protected string query; - protected string reply1; - protected string reply2; - protected int tokens1; - protected int tokens2; - protected int port; - - static readonly object _lock = new object(); - - public TestLLM() - { - Task task = Init(); - task.Wait(); - } - - public virtual async Task Init() - { - Monitor.Enter(_lock); - port = new System.Random().Next(10000, 20000); - SetParameters(); - await DownloadModels(); - gameObject = new GameObject(); - gameObject.SetActive(false); - llm = CreateLLM(); - llmCharacter = CreateLLMCharacter(); - gameObject.SetActive(true); - } - - public virtual void SetParameters() - { - prompt = "You are a scientific assistant and provide short and concise info on the user questions"; - query = "Can you tell me some fun fact about ants in one sentence?"; - reply1 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have human-like intelligence."; - reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, which is a fascinating example of teamwork."; - tokens1 = 20; - tokens2 = 9; - } - - protected virtual string GetModelUrl() - { - return "https://huggingface.co/unsloth/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q4_K_M.gguf"; - } - - public virtual async Task DownloadModels() - { - modelNameLLManager = await LLMManager.DownloadModel(GetModelUrl()); - } - - [Test] - public void TestGetLLMManagerAssetRuntime() - { - string path = ""; - string managerPath = LLM.GetLLMManagerAssetRuntime(path); - Assert.AreEqual(managerPath, path); - - string filename = "lala"; - path = LLMUnitySetup.GetFullPath(filename); - LLMUnitySetup.CreateEmptyFile(path); - managerPath = LLM.GetLLMManagerAssetRuntime(path); - Assert.AreEqual(managerPath, path); - File.Delete(path); - - path = modelNameLLManager; - managerPath = LLM.GetLLMManagerAssetRuntime(path); - Assert.AreEqual(managerPath, LLMManager.GetAssetPath(path)); - - path = LLMUnitySetup.GetAssetPath("lala"); - LLMUnitySetup.CreateEmptyFile(path); - managerPath = LLM.GetLLMManagerAssetRuntime(path); - Assert.AreEqual(managerPath, path); - File.Delete(path); - } - - [Test] - public void TestGetLLMManagerAssetEditor() - { - string path = ""; - string managerPath = LLM.GetLLMManagerAssetEditor(path); - Assert.AreEqual(managerPath, path); - - path = modelNameLLManager; - managerPath = LLM.GetLLMManagerAssetEditor(path); - Assert.AreEqual(managerPath, modelNameLLManager); - - path = LLMManager.Get(modelNameLLManager).path; - managerPath = LLM.GetLLMManagerAssetEditor(path); - Assert.AreEqual(managerPath, modelNameLLManager); - - string filename = "lala"; - path = LLMUnitySetup.GetAssetPath(filename); - LLMUnitySetup.CreateEmptyFile(path); - managerPath = LLM.GetLLMManagerAssetEditor(filename); - Assert.AreEqual(managerPath, filename); - managerPath = LLM.GetLLMManagerAssetEditor(path); - Assert.AreEqual(managerPath, filename); - - path = LLMUnitySetup.GetFullPath(filename); - LLMUnitySetup.CreateEmptyFile(path); - managerPath = LLM.GetLLMManagerAssetEditor(path); - Assert.AreEqual(managerPath, path); - File.Delete(path); - } - - public virtual LLM CreateLLM() - { - LLM llm = gameObject.AddComponent(); - llm.SetModel(modelNameLLManager); - llm.parallelPrompts = 1; - llm.port = port; - return llm; - } - - public virtual LLMCharacter CreateLLMCharacter() - { - LLMCharacter llmCharacter = gameObject.AddComponent(); - llmCharacter.llm = llm; - llmCharacter.playerName = "User"; - llmCharacter.AIName = "Assistant"; - llmCharacter.prompt = prompt; - llmCharacter.temperature = 0; - llmCharacter.seed = 0; - llmCharacter.stream = false; - llmCharacter.numPredict = 50; - llmCharacter.port = port; - return llmCharacter; - } - - [UnityTest] - public IEnumerator RunTests() - { - Task task = RunTestsTask(); - while (!task.IsCompleted) yield return null; - if (error != null) - { - Debug.LogError(error.ToString()); - throw (error); - } - OnDestroy(); - } - - public async Task RunTestsTask() - { - error = null; - try - { - await Tests(); - llm.OnDestroy(); - } - catch (Exception e) - { - error = e; - } - } - - public virtual async Task Tests() - { - await llmCharacter.Tokenize("I", TestTokens); - await llmCharacter.Warmup(); - TestArchitecture(); - TestInitParameters(tokens1, 1); - TestWarmup(); - await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); - TestPostChat(3); - llmCharacter.SetPrompt(llmCharacter.prompt); - llmCharacter.AIName = "False response"; - await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); - TestPostChat(3); - await llmCharacter.Chat("bye!"); - TestPostChat(5); - prompt = "How are you?"; - llmCharacter.SetPrompt(prompt); - await llmCharacter.Chat("hi"); - TestInitParameters(tokens2, 3); - List embeddings = await llmCharacter.Embeddings("hi how are you?"); - TestEmbeddings(embeddings); - } - - public virtual void TestArchitecture() - { - Assert.That(llm.architecture.Contains("avx")); - } - - public void TestInitParameters(int nkeep, int chats) - { - Assert.AreEqual(llmCharacter.nKeep, nkeep); - Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0); - Assert.AreEqual(llmCharacter.chat.Count, chats); - } - - public void TestTokens(List tokens) - { - Assert.AreEqual(tokens, new List { 40 }); - } - - public void TestWarmup() - { - Assert.That(llmCharacter.chat.Count == 1); - } - - public void TestChat(string reply, string replyGT) - { - Debug.Log(reply.Trim()); - var words1 = reply.Trim().Split(new[] { ' ', ',', '.', '!', '?' }, StringSplitOptions.RemoveEmptyEntries); - var words2 = replyGT.Trim().Split(new[] { ' ', ',', '.', '!', '?' }, StringSplitOptions.RemoveEmptyEntries); - var commonWords = words1.Intersect(words2).Count(); - var totalWords = Math.Max(words1.Length, words2.Length); - - Assert.That((double)commonWords / totalWords >= 0.7); - } - - public void TestPostChat(int num) - { - Assert.That(llmCharacter.chat.Count == num); - } - - public void TestEmbeddings(List embeddings) - { - Assert.That(embeddings.Count == 1024); - } - - public virtual void OnDestroy() - { - if (Monitor.IsEntered(_lock)) - { - Monitor.Exit(_lock); - } - } - } - - public class TestLLM_LLMManager_Load : TestLLM - { - public override LLM CreateLLM() - { - LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; - string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - filename = LLMManager.LoadModel(sourcePath); - llm.SetModel(filename); - llm.parallelPrompts = 1; - return llm; - } - } - - public class TestLLM_StreamingAssets_Load : TestLLM - { - string loadPath; - - public override LLM CreateLLM() - { - LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; - string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - loadPath = LLMUnitySetup.GetAssetPath(filename); - if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); - llm.SetModel(loadPath); - llm.parallelPrompts = 1; - return llm; - } - - public override void OnDestroy() - { - base.OnDestroy(); - if (!File.Exists(loadPath)) File.Delete(loadPath); - } - } - - public class TestLLM_SetModel_Warning : TestLLM - { - public override LLM CreateLLM() - { - LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; - string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - llm.SetModel(loadPath); - llm.parallelPrompts = 1; - return llm; - } - } - - public class TestLLM_Lora : TestLLM - { - protected string loraUrl = "https://huggingface.co/phh/Qwen3-0.6B-TLDR-Lora/resolve/main/Qwen3-0.6B-tldr-lora-f16.gguf"; - protected string loraNameLLManager; - protected float loraWeight; - - public override async Task DownloadModels() - { - await base.DownloadModels(); - loraNameLLManager = await LLMManager.DownloadLora(loraUrl); - } - - public override LLM CreateLLM() - { - LLM llm = base.CreateLLM(); - llm.AddLora(loraNameLLManager, loraWeight); - return llm; - } - - public override void SetParameters() - { - prompt = ""; - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) - { - reply1 = "I am sorry, but I cannot assist with this request. Please try again or ask a different question."; - } - else - { - reply1 = "I am sorry, but I cannot respond to your message as it is empty.Could you please provide a meaningful query or content ?"; - } - reply2 = "False response."; - tokens1 = 5; - tokens2 = 9; - loraWeight = 0.9f; - } - - public override async Task Tests() - { - await base.Tests(); - TestModelPaths(); - await TestLoraWeight(); - loraWeight = 0.6f; - llm.SetLoraWeight(loraNameLLManager, loraWeight); - await TestLoraWeight(); - } - - public void TestModelPaths() - { - Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(GetModelUrl()).Split("?")[0]).Replace('\\', '/')); - Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0]).Replace('\\', '/')); - } - - public async Task TestLoraWeight() - { - List loras = await llm.ListLoras(); - Assert.AreEqual(loras[0].scale, loraWeight); - } - } - - public class TestLLM_Remote : TestLLM - { - public override LLM CreateLLM() - { - LLM llm = base.CreateLLM(); - llm.remote = true; - return llm; - } - - public override LLMCharacter CreateLLMCharacter() - { - LLMCharacter llmCharacter = base.CreateLLMCharacter(); - llmCharacter.remote = true; - return llmCharacter; - } - } - - public class TestLLM_Lora_Remote : TestLLM_Lora - { - public override LLM CreateLLM() - { - LLM llm = base.CreateLLM(); - llm.remote = true; - return llm; - } - - public override LLMCharacter CreateLLMCharacter() - { - LLMCharacter llmCharacter = base.CreateLLMCharacter(); - llmCharacter.remote = true; - return llmCharacter; - } - } - - public class TestLLM_Double : TestLLM - { - LLM llm1; - LLMCharacter llmCharacter1; - - public override async Task Init() - { - SetParameters(); - await DownloadModels(); - gameObject = new GameObject(); - gameObject.SetActive(false); - llm = CreateLLM(); - llmCharacter = CreateLLMCharacter(); - llm1 = CreateLLM(); - llmCharacter1 = CreateLLMCharacter(); - gameObject.SetActive(true); - } - } - - public class TestLLMCharacter_Save : TestLLM - { - string saveName = "TestLLMCharacter_Save"; - - public override LLMCharacter CreateLLMCharacter() - { - LLMCharacter llmCharacter = base.CreateLLMCharacter(); - llmCharacter.save = saveName; - llmCharacter.saveCache = true; - foreach (string filename in new string[] - { - llmCharacter.GetJsonSavePath(saveName), - llmCharacter.GetCacheSavePath(saveName) - }) if (File.Exists(filename)) File.Delete(filename); - return llmCharacter; - } - - public override async Task Tests() - { - await base.Tests(); - TestSave(); - } - - public void TestSave() - { - string jsonPath = llmCharacter.GetJsonSavePath(saveName); - string cachePath = llmCharacter.GetCacheSavePath(saveName); - Assert.That(File.Exists(jsonPath)); - Assert.That(File.Exists(cachePath)); - string json = File.ReadAllText(jsonPath); - File.Delete(jsonPath); - File.Delete(cachePath); - - List chatHistory = JsonUtility.FromJson(json).chat; - Assert.AreEqual(chatHistory.Count, 2); - Assert.AreEqual(chatHistory[0].role, llmCharacter.playerName); - Assert.AreEqual(chatHistory[0].content, "hi"); - Assert.AreEqual(chatHistory[1].role, llmCharacter.AIName); - - Assert.AreEqual(llmCharacter.chat.Count, chatHistory.Count + 1); - for (int i = 0; i < chatHistory.Count; i++) - { - Assert.AreEqual(chatHistory[i].role, llmCharacter.chat[i + 1].role); - Assert.AreEqual(chatHistory[i].content, llmCharacter.chat[i + 1].content); - } - } - } - - public class TestLLM_CUDA : TestLLM - { - public override LLM CreateLLM() - { - LLM llm = base.CreateLLM(); - llm.numGPULayers = 10; - return llm; - } - - public override void SetParameters() - { - base.SetParameters(); - reply1 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have a brain."; - if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer) - { - reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, which is a fascinating example of teamwork."; - } - } - - public override void TestArchitecture() - { - Assert.That(llm.architecture.Contains("cuda")); - } - } - - public class TestLLM_CUDA_full : TestLLM_CUDA - { - public override void TestArchitecture() - { - Assert.That(llm.architecture.Contains("cuda") && llm.architecture.Contains("full")); - } - } - - public class TestLLM_CUDA_full_attention : TestLLM_CUDA_full - { - public override LLM CreateLLM() - { - LLM llm = base.CreateLLM(); - llm.flashAttention = true; - return llm; - } - - public override void SetParameters() - { - base.SetParameters(); - if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer) - { - reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have human-like intelligence."; - } - } - } -} +// using NUnit.Framework; +// using LLMUnity; +// using UnityEngine; +// using System.Threading.Tasks; +// using System.Collections.Generic; +// using System; +// using System.Collections; +// using System.IO; +// using System.Linq; +// using System.Threading; +// using UnityEngine.TestTools; + +// namespace LLMUnityTests +// { +// public class TestLLMLoraAssignment +// { +// [Test] +// public void TestLoras() +// { +// GameObject gameObject = new GameObject(); +// gameObject.SetActive(false); +// LLM llm = gameObject.AddComponent(); + +// string lora1 = LLMUnitySetup.GetFullPath("lala"); +// string lora2Rel = "test/lala"; +// string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel); +// LLMUnitySetup.CreateEmptyFile(lora1); +// Directory.CreateDirectory(Path.GetDirectoryName(lora2)); +// LLMUnitySetup.CreateEmptyFile(lora2); + +// llm.AddLora(lora1); +// llm.AddLora(lora2); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "1,1"); + +// llm.RemoveLoras(); +// Assert.AreEqual(llm.lora, ""); +// Assert.AreEqual(llm.loraWeights, ""); + +// llm.AddLora(lora1, 0.8f); +// llm.AddLora(lora2Rel, 0.9f); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "0.8,0.9"); + +// llm.SetLoraWeight(lora2Rel, 0.7f); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "0.8,0.7"); + +// llm.RemoveLora(lora2Rel); +// Assert.AreEqual(llm.lora, lora1); +// Assert.AreEqual(llm.loraWeights, "0.8"); + +// llm.AddLora(lora2Rel); +// llm.SetLoraWeight(lora2Rel, 0.5f); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "0.8,0.5"); + +// llm.SetLoraWeight(lora2, 0.1f); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "0.8,0.1"); + +// Dictionary loraToWeight = new Dictionary(); +// loraToWeight[lora1] = 0; +// loraToWeight[lora2] = 0.2f; +// llm.SetLoraWeights(loraToWeight); +// Assert.AreEqual(llm.lora, lora1 + "," + lora2); +// Assert.AreEqual(llm.loraWeights, "0,0.2"); + +// File.Delete(lora1); +// File.Delete(lora2); +// } +// } + +// public class TestLLM +// { +// protected string modelNameLLManager; + +// protected GameObject gameObject; +// protected LLM llm; +// protected LLMAgent llmAgent; +// protected Exception error = null; +// protected string prompt; +// protected string query; +// protected string reply1; +// protected string reply2; +// protected int tokens1; +// protected int tokens2; +// protected int port; + +// static readonly object _lock = new object(); + +// public TestLLM() +// { +// Task task = Init(); +// task.Wait(); +// } + +// public virtual async Task Init() +// { +// Monitor.Enter(_lock); +// port = new System.Random().Next(10000, 20000); +// SetParameters(); +// await DownloadModels(); +// gameObject = new GameObject(); +// gameObject.SetActive(false); +// llm = CreateLLM(); +// llmAgent = CreateLLMCharacter(); +// gameObject.SetActive(true); +// } + +// public virtual void SetParameters() +// { +// prompt = "You are a scientific assistant and provide short and concise info on the user questions"; +// query = "Can you tell me some fun fact about ants in one sentence?"; +// reply1 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have human-like intelligence."; +// reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, which is a fascinating example of teamwork."; +// tokens1 = 20; +// tokens2 = 9; +// } + +// protected virtual string GetModelUrl() +// { +// return "https://huggingface.co/unsloth/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q4_K_M.gguf"; +// } + +// public virtual async Task DownloadModels() +// { +// modelNameLLManager = await LLMManager.DownloadModel(GetModelUrl()); +// } + +// [Test] +// public void TestGetLLMManagerAssetRuntime() +// { +// string path = ""; +// string managerPath = LLM.GetLLMManagerAssetRuntime(path); +// Assert.AreEqual(managerPath, path); + +// string filename = "lala"; +// path = LLMUnitySetup.GetFullPath(filename); +// LLMUnitySetup.CreateEmptyFile(path); +// managerPath = LLM.GetLLMManagerAssetRuntime(path); +// Assert.AreEqual(managerPath, path); +// File.Delete(path); + +// path = modelNameLLManager; +// managerPath = LLM.GetLLMManagerAssetRuntime(path); +// Assert.AreEqual(managerPath, LLMManager.GetAssetPath(path)); + +// path = LLMUnitySetup.GetAssetPath("lala"); +// LLMUnitySetup.CreateEmptyFile(path); +// managerPath = LLM.GetLLMManagerAssetRuntime(path); +// Assert.AreEqual(managerPath, path); +// File.Delete(path); +// } + +// [Test] +// public void TestGetLLMManagerAssetEditor() +// { +// string path = ""; +// string managerPath = LLM.GetLLMManagerAssetEditor(path); +// Assert.AreEqual(managerPath, path); + +// path = modelNameLLManager; +// managerPath = LLM.GetLLMManagerAssetEditor(path); +// Assert.AreEqual(managerPath, modelNameLLManager); + +// path = LLMManager.Get(modelNameLLManager).path; +// managerPath = LLM.GetLLMManagerAssetEditor(path); +// Assert.AreEqual(managerPath, modelNameLLManager); + +// string filename = "lala"; +// path = LLMUnitySetup.GetAssetPath(filename); +// LLMUnitySetup.CreateEmptyFile(path); +// managerPath = LLM.GetLLMManagerAssetEditor(filename); +// Assert.AreEqual(managerPath, filename); +// managerPath = LLM.GetLLMManagerAssetEditor(path); +// Assert.AreEqual(managerPath, filename); + +// path = LLMUnitySetup.GetFullPath(filename); +// LLMUnitySetup.CreateEmptyFile(path); +// managerPath = LLM.GetLLMManagerAssetEditor(path); +// Assert.AreEqual(managerPath, path); +// File.Delete(path); +// } + +// public virtual LLM CreateLLM() +// { +// LLM llm = gameObject.AddComponent(); +// llm.SetModel(modelNameLLManager); +// llm.parallelPrompts = 1; +// llm.port = port; +// return llm; +// } + +// public virtual LLMAgent CreateLLMCharacter() +// { +// LLMAgent llmAgent = gameObject.AddComponent(); +// llmAgent.llm = llm; +// llmAgent.userRole = "User"; +// llmAgent.assistantRole = "Assistant"; +// llmAgent.prompt = prompt; +// llmAgent.temperature = 0; +// llmAgent.seed = 0; +// llmAgent.stream = false; +// llmAgent.numPredict = 50; +// llmAgent.port = port; +// return llmAgent; +// } + +// [UnityTest] +// public IEnumerator RunTests() +// { +// Task task = RunTestsTask(); +// while (!task.IsCompleted) yield return null; +// if (error != null) +// { +// Debug.LogError(error.ToString()); +// throw (error); +// } +// OnDestroy(); +// } + +// public async Task RunTestsTask() +// { +// error = null; +// try +// { +// await Tests(); +// llm.OnDestroy(); +// } +// catch (Exception e) +// { +// error = e; +// } +// } + +// public virtual async Task Tests() +// { +// await llmAgent.Tokenize("I", TestTokens); +// await llmAgent.Warmup(); +// TestArchitecture(); +// TestInitParameters(tokens1, 1); +// TestWarmup(); +// await llmAgent.Chat(query, (string reply) => TestChat(reply, reply1)); +// TestPostChat(3); +// llmAgent.SetPrompt(llmAgent.prompt); +// llmAgent.assistantRole = "False response"; +// await llmAgent.Chat(query, (string reply) => TestChat(reply, reply2)); +// TestPostChat(3); +// await llmAgent.Chat("bye!"); +// TestPostChat(5); +// prompt = "How are you?"; +// llmAgent.SetPrompt(prompt); +// await llmAgent.Chat("hi"); +// TestInitParameters(tokens2, 3); +// List embeddings = await llmAgent.Embeddings("hi how are you?"); +// TestEmbeddings(embeddings); +// } + +// public virtual void TestArchitecture() +// { +// Assert.That(llm.architecture.Contains("avx")); +// } + +// public void TestInitParameters(int nkeep, int chats) +// { +// Assert.AreEqual(llmAgent.nKeep, nkeep); +// Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmAgent.userRole, llmAgent.assistantRole).Length > 0); +// Assert.AreEqual(llmAgent.chat.Count, chats); +// } + +// public void TestTokens(List tokens) +// { +// Assert.AreEqual(tokens, new List { 40 }); +// } + +// public void TestWarmup() +// { +// Assert.That(llmAgent.chat.Count == 1); +// } + +// public void TestChat(string reply, string replyGT) +// { +// Debug.Log(reply.Trim()); +// var words1 = reply.Trim().Split(new[] { ' ', ',', '.', '!', '?' }, StringSplitOptions.RemoveEmptyEntries); +// var words2 = replyGT.Trim().Split(new[] { ' ', ',', '.', '!', '?' }, StringSplitOptions.RemoveEmptyEntries); +// var commonWords = words1.Intersect(words2).Count(); +// var totalWords = Math.Max(words1.Length, words2.Length); + +// Assert.That((double)commonWords / totalWords >= 0.7); +// } + +// public void TestPostChat(int num) +// { +// Assert.That(llmAgent.chat.Count == num); +// } + +// public void TestEmbeddings(List embeddings) +// { +// Assert.That(embeddings.Count == 1024); +// } + +// public virtual void OnDestroy() +// { +// if (Monitor.IsEntered(_lock)) +// { +// Monitor.Exit(_lock); +// } +// } +// } + +// public class TestLLM_LLMManager_Load : TestLLM +// { +// public override LLM CreateLLM() +// { +// LLM llm = gameObject.AddComponent(); +// string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; +// string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); +// filename = LLMManager.LoadModel(sourcePath); +// llm.SetModel(filename); +// llm.parallelPrompts = 1; +// return llm; +// } +// } + +// public class TestLLM_StreamingAssets_Load : TestLLM +// { +// string loadPath; + +// public override LLM CreateLLM() +// { +// LLM llm = gameObject.AddComponent(); +// string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; +// string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); +// loadPath = LLMUnitySetup.GetAssetPath(filename); +// if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); +// llm.SetModel(loadPath); +// llm.parallelPrompts = 1; +// return llm; +// } + +// public override void OnDestroy() +// { +// base.OnDestroy(); +// if (!File.Exists(loadPath)) File.Delete(loadPath); +// } +// } + +// public class TestLLM_SetModel_Warning : TestLLM +// { +// public override LLM CreateLLM() +// { +// LLM llm = gameObject.AddComponent(); +// string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; +// string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); +// llm.SetModel(loadPath); +// llm.parallelPrompts = 1; +// return llm; +// } +// } + +// public class TestLLM_Lora : TestLLM +// { +// protected string loraUrl = "https://huggingface.co/phh/Qwen3-0.6B-TLDR-Lora/resolve/main/Qwen3-0.6B-tldr-lora-f16.gguf"; +// protected string loraNameLLManager; +// protected float loraWeight; + +// public override async Task DownloadModels() +// { +// await base.DownloadModels(); +// loraNameLLManager = await LLMManager.DownloadLora(loraUrl); +// } + +// public override LLM CreateLLM() +// { +// LLM llm = base.CreateLLM(); +// llm.AddLora(loraNameLLManager, loraWeight); +// return llm; +// } + +// public override void SetParameters() +// { +// prompt = ""; +// if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) +// { +// reply1 = "I am sorry, but I cannot assist with this request. Please try again or ask a different question."; +// } +// else +// { +// reply1 = "I am sorry, but I cannot respond to your message as it is empty.Could you please provide a meaningful query or content ?"; +// } +// reply2 = "False response."; +// tokens1 = 5; +// tokens2 = 9; +// loraWeight = 0.9f; +// } + +// public override async Task Tests() +// { +// await base.Tests(); +// TestModelPaths(); +// TestLoraWeight(); +// loraWeight = 0.6f; +// llm.SetLoraWeight(loraNameLLManager, loraWeight); +// TestLoraWeight(); +// } + +// public void TestModelPaths() +// { +// Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(GetModelUrl()).Split("?")[0]).Replace('\\', '/')); +// Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0]).Replace('\\', '/')); +// } + +// public void TestLoraWeight() +// { +// List loras = llm.ListLoras(); +// Assert.AreEqual(loras[0].Scale, loraWeight); +// } +// } + +// public class TestLLM_Remote : TestLLM +// { +// public override LLM CreateLLM() +// { +// LLM llm = base.CreateLLM(); +// llm.remote = true; +// return llm; +// } + +// public override LLMAgent CreateLLMCharacter() +// { +// LLMAgent llmAgent = base.CreateLLMCharacter(); +// llmAgent.remote = true; +// return llmAgent; +// } +// } + +// public class TestLLM_Lora_Remote : TestLLM_Lora +// { +// public override LLM CreateLLM() +// { +// LLM llm = base.CreateLLM(); +// llm.remote = true; +// return llm; +// } + +// public override LLMAgent CreateLLMCharacter() +// { +// LLMAgent llmAgent = base.CreateLLMCharacter(); +// llmAgent.remote = true; +// return llmAgent; +// } +// } + +// public class TestLLM_Double : TestLLM +// { +// LLM llm1; +// LLMAgent llmCharacter1; + +// public override async Task Init() +// { +// SetParameters(); +// await DownloadModels(); +// gameObject = new GameObject(); +// gameObject.SetActive(false); +// llm = CreateLLM(); +// llmAgent = CreateLLMCharacter(); +// llm1 = CreateLLM(); +// llmCharacter1 = CreateLLMCharacter(); +// gameObject.SetActive(true); +// } +// } + +// public class TestLLMCharacter_Save : TestLLM +// { +// string saveName = "TestLLMCharacter_Save"; + +// public override LLMAgent CreateLLMCharacter() +// { +// LLMAgent llmAgent = base.CreateLLMCharacter(); +// llmAgent.save = saveName; +// llmAgent.saveCache = true; +// foreach (string filename in new string[] +// { +// llmAgent.GetJsonSavePath(saveName), +// llmAgent.GetCacheSavePath(saveName) +// }) if (File.Exists(filename)) File.Delete(filename); +// return llmAgent; +// } + +// public override async Task Tests() +// { +// await base.Tests(); +// TestSave(); +// } + +// public void TestSave() +// { +// string jsonPath = llmAgent.GetJsonSavePath(saveName); +// string cachePath = llmAgent.GetCacheSavePath(saveName); +// Assert.That(File.Exists(jsonPath)); +// Assert.That(File.Exists(cachePath)); +// string json = File.ReadAllText(jsonPath); +// File.Delete(jsonPath); +// File.Delete(cachePath); + +// List chatHistory = JsonUtility.FromJson(json).chat; +// Assert.AreEqual(chatHistory.Count, 2); +// Assert.AreEqual(chatHistory[0].role, llmAgent.userRole); +// Assert.AreEqual(chatHistory[0].content, "hi"); +// Assert.AreEqual(chatHistory[1].role, llmAgent.assistantRole); + +// Assert.AreEqual(llmAgent.chat.Count, chatHistory.Count + 1); +// for (int i = 0; i < chatHistory.Count; i++) +// { +// Assert.AreEqual(chatHistory[i].role, llmAgent.chat[i + 1].role); +// Assert.AreEqual(chatHistory[i].content, llmAgent.chat[i + 1].content); +// } +// } +// } + +// public class TestLLM_CUDA : TestLLM +// { +// public override LLM CreateLLM() +// { +// LLM llm = base.CreateLLM(); +// llm.numGPULayers = 10; +// return llm; +// } + +// public override void SetParameters() +// { +// base.SetParameters(); +// reply1 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have a brain."; +// if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer) +// { +// reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, which is a fascinating example of teamwork."; +// } +// } + +// public override void TestArchitecture() +// { +// Assert.That(llm.architecture.Contains("cuda")); +// } +// } + +// public class TestLLM_CUDA_full : TestLLM_CUDA +// { +// public override void TestArchitecture() +// { +// Assert.That(llm.architecture.Contains("cuda") && llm.architecture.Contains("full")); +// } +// } + +// public class TestLLM_CUDA_full_attention : TestLLM_CUDA_full +// { +// public override LLM CreateLLM() +// { +// LLM llm = base.CreateLLM(); +// llm.flashAttention = true; +// return llm; +// } + +// public override void SetParameters() +// { +// base.SetParameters(); +// if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer) +// { +// reply2 = "Sure! Here's a fun fact: Ants work together to build complex structures like nests, even though they don't have human-like intelligence."; +// } +// } +// } +// } diff --git a/Tests/Runtime/TestLLMChatTemplates.cs b/Tests/Runtime/TestLLMChatTemplates.cs index bafa6cb2..10969196 100644 --- a/Tests/Runtime/TestLLMChatTemplates.cs +++ b/Tests/Runtime/TestLLMChatTemplates.cs @@ -1,200 +1,200 @@ -using LLMUnity; -using System.Collections.Generic; -using NUnit.Framework.Internal; -using NUnit.Framework; - -namespace LLMUnityTests -{ - public class TestChatTemplate - { - List messages = new List() - { - new ChatMessage {role = "system", content = "you are a bot"}, - new ChatMessage {role = "user", content = "Hello, how are you?"}, - new ChatMessage {role = "assistant", content = "I'm doing great. How can I help you today?"}, - new ChatMessage {role = "user", content = "I'd like to show off how chat templating works!"}, - new ChatMessage {role = "assistant", content = "chat template is awesome"}, - new ChatMessage {role = "user", content = "do you think so?"}, - }; - - [Test] - public void TestChatML() - { - Assert.AreEqual( - new ChatMLTemplate().ComputePrompt(messages, "user", "assistant"), - "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n" - ); - } - - [Test] - public void TestGemma() - { - Assert.AreEqual( - new GemmaTemplate().ComputePrompt(messages, "user", "assistant"), - "user\nyou are a bot\n\nHello, how are you?\nmodel\nI'm doing great. How can I help you today?\nuser\nI'd like to show off how chat templating works!\nmodel\nchat template is awesome\nuser\ndo you think so?\nmodel\n" - ); - } - - [Test] - public void TestMistralInstruct() - { - Assert.AreEqual( - new MistralInstructTemplate().ComputePrompt(messages, "user", "assistant"), - "[INST] you are a bot\n\nHello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]chat template is awesome[INST] do you think so? [/INST]" - ); - } - - [Test] - public void TestMistralChat() - { - Assert.AreEqual( - new MistralChatTemplate().ComputePrompt(messages, "user", "assistant"), - "[INST] you are a bot\n\n### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today?[INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome[INST] ### user: do you think so? [/INST]### assistant:" - ); - } - - [Test] - public void TestLLama2() - { - Assert.AreEqual( - new LLama2Template().ComputePrompt(messages, "user", "assistant"), - "[INST] <>\nyou are a bot\n<> Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]chat template is awesome [INST] do you think so? [/INST]" - ); - } - - [Test] - public void TestLLama2Chat() - { - Assert.AreEqual( - new LLama2ChatTemplate().ComputePrompt(messages, "user", "assistant"), - "[INST] <>\nyou are a bot\n<> ### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today? [INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome [INST] ### user: do you think so? [/INST]### assistant:" - ); - } - - [Test] - public void TestLLama3Chat() - { - Assert.AreEqual( - new LLama3ChatTemplate().ComputePrompt(messages, "user", "assistant"), - "<|start_header_id|>system<|end_header_id|>\n\nyou are a bot<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nchat template is awesome<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ndo you think so?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - ); - } - - [Test] - public void TestAlpaca() - { - Assert.AreEqual( - new AlpacaTemplate().ComputePrompt(messages, "user", "assistant"), - "you are a bot\n\n### user: Hello, how are you?\n### assistant: I'm doing great. How can I help you today?\n### user: I'd like to show off how chat templating works!\n### assistant: chat template is awesome\n### user: do you think so?\n### assistant:" - ); - } - - [Test] - public void TestVicuna() - { - Assert.AreEqual( - new VicunaTemplate().ComputePrompt(messages, "user", "assistant"), - "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:" - ); - } - - [Test] - public void TestPhi2() - { - Assert.AreEqual( - new Phi2Template().ComputePrompt(messages, "user", "assistant"), - "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:" - ); - } - - [Test] - public void TestPhi3() - { - Assert.AreEqual( - new Phi3Template().ComputePrompt(messages, "user", "assistant"), - "<|user|>\nyou are a bot\n\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n" - ); - } - - [Test] - public void TestPhi3_5() - { - Assert.AreEqual( - new Phi3_5Template().ComputePrompt(messages, "user", "assistant"), - "<|system|>\nyou are a bot<|end|>\n<|user|>\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n" - ); - } - - [Test] - public void TestPhi4Mini() - { - Assert.AreEqual( - new Phi4MiniTemplate().ComputePrompt(messages, "user", "assistant"), - "<|system|>you are a bot<|end|><|user|>Hello, how are you?<|end|><|assistant|>I'm doing great. How can I help you today?<|end|><|user|>I'd like to show off how chat templating works!<|end|><|assistant|>chat template is awesome<|end|><|user|>do you think so?<|end|><|assistant|>" - ); - } - - [Test] - public void TestPhi4() - { - Assert.AreEqual( - new Phi4Template().ComputePrompt(messages, "user", "assistant"), - "<|im_start|>system<|im_sep|>you are a bot<|im_end|><|im_start|>user<|im_sep|>Hello, how are you?<|im_end|><|im_start|>assistant<|im_sep|>I'm doing great. How can I help you today?<|im_end|><|im_start|>user<|im_sep|>I'd like to show off how chat templating works!<|im_end|><|im_start|>assistant<|im_sep|>chat template is awesome<|im_end|><|im_start|>user<|im_sep|>do you think so?<|im_end|><|im_start|>assistant<|im_sep|>" - ); - } - - [Test] - public void TestZephyr() - { - Assert.AreEqual( - new ZephyrTemplate().ComputePrompt(messages, "user", "assistant"), - "<|system|>\nyou are a bot\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n<|assistant|>\nchat template is awesome\n<|user|>\ndo you think so?\n<|assistant|>\n" - ); - } - - [Test] - public void TestDeepSeekV2() - { - Assert.AreEqual( - new DeepSeekV2Template().ComputePrompt(messages, "user", "assistant"), - "<|begin▁of▁sentence|>you are a bot\n\nUser: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\nAssistant: chat template is awesome<|end▁of▁sentence|>User: do you think so?\n\nAssistant:" - ); - } - - [Test] - public void TestDeepSeekV3() - { - Assert.AreEqual( - new DeepSeekV3Template().ComputePrompt(messages, "user", "assistant"), - "<|begin▁of▁sentence|>you are a bot\n\n<|User|>Hello, how are you?<|Assistant|>I'm doing great. How can I help you today?<|end▁of▁sentence|><|User|>I'd like to show off how chat templating works!<|Assistant|>chat template is awesome<|end▁of▁sentence|><|User|>do you think so?<|Assistant|>" - ); - } - - [Test] - public void TestDeepSeekR1() - { - Assert.AreEqual( - new DeepSeekR1Template().ComputePrompt(messages, "user", "assistant"), - "<|begin▁of▁sentence|>you are a bot\n\n<|User|>Hello, how are you?<|Assistant|>I'm doing great. How can I help you today?<|end▁of▁sentence|><|User|>I'd like to show off how chat templating works!<|Assistant|>chat template is awesome<|end▁of▁sentence|><|User|>do you think so?<|Assistant|>\n\n\n\n" - ); - } - - [Test] - public void TestQwen3() - { - Assert.AreEqual( - new Qwen3Template().ComputePrompt(messages, "user", "assistant"), - "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" - ); - } - - [Test] - public void TestBitNet() - { - Assert.AreEqual( - new BitNetTemplate().ComputePrompt(messages, "user", "assistant"), - "System: you are a bot<|eot_id|>User: Hello, how are you?<|eot_id|>Assistant: I'm doing great. How can I help you today?<|eot_id|>User: I'd like to show off how chat templating works!<|eot_id|>Assistant: chat template is awesome<|eot_id|>User: do you think so?<|eot_id|>Assistant: " - ); - } - } -} +// using LLMUnity; +// using System.Collections.Generic; +// using NUnit.Framework.Internal; +// using NUnit.Framework; + +// namespace LLMUnityTests +// { +// public class TestChatTemplate +// { +// List messages = new List() +// { +// new ChatMessage {role = "system", content = "you are a bot"}, +// new ChatMessage {role = "user", content = "Hello, how are you?"}, +// new ChatMessage {role = "assistant", content = "I'm doing great. How can I help you today?"}, +// new ChatMessage {role = "user", content = "I'd like to show off how chat templating works!"}, +// new ChatMessage {role = "assistant", content = "chat template is awesome"}, +// new ChatMessage {role = "user", content = "do you think so?"}, +// }; + +// [Test] +// public void TestChatML() +// { +// Assert.AreEqual( +// new ChatMLTemplate().ComputePrompt(messages, "user", "assistant"), +// "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n" +// ); +// } + +// [Test] +// public void TestGemma() +// { +// Assert.AreEqual( +// new GemmaTemplate().ComputePrompt(messages, "user", "assistant"), +// "user\nyou are a bot\n\nHello, how are you?\nmodel\nI'm doing great. How can I help you today?\nuser\nI'd like to show off how chat templating works!\nmodel\nchat template is awesome\nuser\ndo you think so?\nmodel\n" +// ); +// } + +// [Test] +// public void TestMistralInstruct() +// { +// Assert.AreEqual( +// new MistralInstructTemplate().ComputePrompt(messages, "user", "assistant"), +// "[INST] you are a bot\n\nHello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]chat template is awesome[INST] do you think so? [/INST]" +// ); +// } + +// [Test] +// public void TestMistralChat() +// { +// Assert.AreEqual( +// new MistralChatTemplate().ComputePrompt(messages, "user", "assistant"), +// "[INST] you are a bot\n\n### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today?[INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome[INST] ### user: do you think so? [/INST]### assistant:" +// ); +// } + +// [Test] +// public void TestLLama2() +// { +// Assert.AreEqual( +// new LLama2Template().ComputePrompt(messages, "user", "assistant"), +// "[INST] <>\nyou are a bot\n<> Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]chat template is awesome [INST] do you think so? [/INST]" +// ); +// } + +// [Test] +// public void TestLLama2Chat() +// { +// Assert.AreEqual( +// new LLama2ChatTemplate().ComputePrompt(messages, "user", "assistant"), +// "[INST] <>\nyou are a bot\n<> ### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today? [INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome [INST] ### user: do you think so? [/INST]### assistant:" +// ); +// } + +// [Test] +// public void TestLLama3Chat() +// { +// Assert.AreEqual( +// new LLama3ChatTemplate().ComputePrompt(messages, "user", "assistant"), +// "<|start_header_id|>system<|end_header_id|>\n\nyou are a bot<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nchat template is awesome<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ndo you think so?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +// ); +// } + +// [Test] +// public void TestAlpaca() +// { +// Assert.AreEqual( +// new AlpacaTemplate().ComputePrompt(messages, "user", "assistant"), +// "you are a bot\n\n### user: Hello, how are you?\n### assistant: I'm doing great. How can I help you today?\n### user: I'd like to show off how chat templating works!\n### assistant: chat template is awesome\n### user: do you think so?\n### assistant:" +// ); +// } + +// [Test] +// public void TestVicuna() +// { +// Assert.AreEqual( +// new VicunaTemplate().ComputePrompt(messages, "user", "assistant"), +// "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:" +// ); +// } + +// [Test] +// public void TestPhi2() +// { +// Assert.AreEqual( +// new Phi2Template().ComputePrompt(messages, "user", "assistant"), +// "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:" +// ); +// } + +// [Test] +// public void TestPhi3() +// { +// Assert.AreEqual( +// new Phi3Template().ComputePrompt(messages, "user", "assistant"), +// "<|user|>\nyou are a bot\n\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n" +// ); +// } + +// [Test] +// public void TestPhi3_5() +// { +// Assert.AreEqual( +// new Phi3_5Template().ComputePrompt(messages, "user", "assistant"), +// "<|system|>\nyou are a bot<|end|>\n<|user|>\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n" +// ); +// } + +// [Test] +// public void TestPhi4Mini() +// { +// Assert.AreEqual( +// new Phi4MiniTemplate().ComputePrompt(messages, "user", "assistant"), +// "<|system|>you are a bot<|end|><|user|>Hello, how are you?<|end|><|assistant|>I'm doing great. How can I help you today?<|end|><|user|>I'd like to show off how chat templating works!<|end|><|assistant|>chat template is awesome<|end|><|user|>do you think so?<|end|><|assistant|>" +// ); +// } + +// [Test] +// public void TestPhi4() +// { +// Assert.AreEqual( +// new Phi4Template().ComputePrompt(messages, "user", "assistant"), +// "<|im_start|>system<|im_sep|>you are a bot<|im_end|><|im_start|>user<|im_sep|>Hello, how are you?<|im_end|><|im_start|>assistant<|im_sep|>I'm doing great. How can I help you today?<|im_end|><|im_start|>user<|im_sep|>I'd like to show off how chat templating works!<|im_end|><|im_start|>assistant<|im_sep|>chat template is awesome<|im_end|><|im_start|>user<|im_sep|>do you think so?<|im_end|><|im_start|>assistant<|im_sep|>" +// ); +// } + +// [Test] +// public void TestZephyr() +// { +// Assert.AreEqual( +// new ZephyrTemplate().ComputePrompt(messages, "user", "assistant"), +// "<|system|>\nyou are a bot\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n<|assistant|>\nchat template is awesome\n<|user|>\ndo you think so?\n<|assistant|>\n" +// ); +// } + +// [Test] +// public void TestDeepSeekV2() +// { +// Assert.AreEqual( +// new DeepSeekV2Template().ComputePrompt(messages, "user", "assistant"), +// "<|begin▁of▁sentence|>you are a bot\n\nUser: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\nAssistant: chat template is awesome<|end▁of▁sentence|>User: do you think so?\n\nAssistant:" +// ); +// } + +// [Test] +// public void TestDeepSeekV3() +// { +// Assert.AreEqual( +// new DeepSeekV3Template().ComputePrompt(messages, "user", "assistant"), +// "<|begin▁of▁sentence|>you are a bot\n\n<|User|>Hello, how are you?<|Assistant|>I'm doing great. How can I help you today?<|end▁of▁sentence|><|User|>I'd like to show off how chat templating works!<|Assistant|>chat template is awesome<|end▁of▁sentence|><|User|>do you think so?<|Assistant|>" +// ); +// } + +// [Test] +// public void TestDeepSeekR1() +// { +// Assert.AreEqual( +// new DeepSeekR1Template().ComputePrompt(messages, "user", "assistant"), +// "<|begin▁of▁sentence|>you are a bot\n\n<|User|>Hello, how are you?<|Assistant|>I'm doing great. How can I help you today?<|end▁of▁sentence|><|User|>I'd like to show off how chat templating works!<|Assistant|>chat template is awesome<|end▁of▁sentence|><|User|>do you think so?<|Assistant|>\n\n\n\n" +// ); +// } + +// [Test] +// public void TestQwen3() +// { +// Assert.AreEqual( +// new Qwen3Template().ComputePrompt(messages, "user", "assistant"), +// "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" +// ); +// } + +// [Test] +// public void TestBitNet() +// { +// Assert.AreEqual( +// new BitNetTemplate().ComputePrompt(messages, "user", "assistant"), +// "System: you are a bot<|eot_id|>User: Hello, how are you?<|eot_id|>Assistant: I'm doing great. How can I help you today?<|eot_id|>User: I'd like to show off how chat templating works!<|eot_id|>Assistant: chat template is awesome<|eot_id|>User: do you think so?<|eot_id|>Assistant: " +// ); +// } +// } +// } diff --git a/Tests/Runtime/TestSearch.cs b/Tests/Runtime/TestSearch.cs index 5a0872b3..beb4acd7 100644 --- a/Tests/Runtime/TestSearch.cs +++ b/Tests/Runtime/TestSearch.cs @@ -67,40 +67,30 @@ public static bool ApproxEqual(float x1, float x2) } [UnityTest] - public virtual IEnumerator RunTests() - { - Task task = RunTestsTask(); - while (!task.IsCompleted) yield return null; - if (error != null) - { - Debug.LogError(error.ToString()); - throw (error); - } - } - - public virtual async Task RunTestsTask() + public virtual void RunTests() { error = null; try { - await Tests(); + Tests(); llm.OnDestroy(); } - catch (Exception e) + catch (Exception e) { - error = e; + Debug.LogError(e.ToString()); + throw (e); } } - public virtual async Task Tests() + public virtual void Tests() { - await TestAdd(); - await TestSearch(); - await TestIncrementalSearch(); - await TestSaveLoad(); + TestAdd(); + TestSearch(); + TestIncrementalSearch(); + TestSaveLoad(); } - public virtual async Task TestAdd() + public virtual void TestAdd() { void CheckCount(int[] nums) { @@ -114,32 +104,32 @@ void CheckCount(int[] nums) } int key, num; - key = await search.Add(weather); + key = search.Add(weather); Assert.That(key == 0); Assert.That(search.Get(key) == weather); Assert.That(search.Count() == 1); search.Remove(key); Assert.That(search.Count() == 0); - key = await search.Add(weather); + key = search.Add(weather); Assert.That(key == 1); - key = await search.Add(raining); + key = search.Add(raining); Assert.That(key == 2); - key = await search.Add(sometext); + key = search.Add(sometext); Assert.That(key == 3); Assert.That(search.Count() == 3); search.Clear(); Assert.That(search.Count() == 0); - key = await search.Add(weather, "0"); + key = search.Add(weather, "0"); Assert.That(key == 0); - key = await search.Add(raining, "0"); + key = search.Add(raining, "0"); Assert.That(key == 1); - key = await search.Add(weather, "1"); + key = search.Add(weather, "1"); Assert.That(key == 2); - key = await search.Add(sometext, "1"); + key = search.Add(sometext, "1"); Assert.That(key == 3); - key = await search.Add(sometext, "2"); + key = search.Add(sometext, "2"); Assert.That(key == 4); CheckCount(new int[] {2, 2, 1}); num = search.Remove(weather, "0"); @@ -165,20 +155,20 @@ void CheckCount(int[] nums) Assert.That(search.Count() == 0); } - public virtual async Task TestSearch() + public virtual void TestSearch() { string[] results; float[] distances; - (results, distances) = await search.Search(weather, 1); + (results, distances) = search.Search(weather, 1); Assert.That(results.Length == 0); Assert.That(distances.Length == 0); - await search.Add(weather); - await search.Add(raining); - await search.Add(sometext); + search.Add(weather); + search.Add(raining); + search.Add(sometext); - (results, distances) = await search.Search(weather, 2); + (results, distances) = search.Search(weather, 2); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], weather); @@ -186,7 +176,7 @@ public virtual async Task TestSearch() Assert.That(ApproxEqual(distances[0], 0)); Assert.That(ApproxEqual(distances[1], weatherRainingDiff)); - (results, distances) = await search.Search(raining, 2); + (results, distances) = search.Search(raining, 2); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], raining); @@ -196,12 +186,12 @@ public virtual async Task TestSearch() search.Clear(); - await search.Add(weather, "0"); - await search.Add(raining, "1"); - await search.Add(sometext, "0"); - await search.Add(sometext, "1"); + search.Add(weather, "0"); + search.Add(raining, "1"); + search.Add(sometext, "0"); + search.Add(sometext, "1"); - (results, distances) = await search.Search(weather, 2, "0"); + (results, distances) = search.Search(weather, 2, "0"); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], weather); @@ -209,7 +199,7 @@ public virtual async Task TestSearch() Assert.That(ApproxEqual(distances[0], 0)); Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); - (results, distances) = await search.Search(weather, 2, "0"); + (results, distances) = search.Search(weather, 2, "0"); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], weather); @@ -217,14 +207,14 @@ public virtual async Task TestSearch() Assert.That(ApproxEqual(distances[0], 0)); Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); - (results, distances) = await search.Search(weather, 2, "1"); + (results, distances) = search.Search(weather, 2, "1"); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], raining); Assert.AreEqual(results[1], sometext); Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); - (results, distances) = await search.Search(weather, 3, "1"); + (results, distances) = search.Search(weather, 3, "1"); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], raining); @@ -233,13 +223,13 @@ public virtual async Task TestSearch() search.Clear(); } - public async Task TestIncrementalSearch() + public void TestIncrementalSearch() { string[] results; float[] distances; bool completed; - int searchKey = await search.IncrementalSearch(weather); + int searchKey = search.IncrementalSearch(weather); (results, distances, completed) = search.IncrementalFetch(searchKey, 1); Assert.That(searchKey == 0); Assert.That(results.Length == 0); @@ -247,11 +237,11 @@ public async Task TestIncrementalSearch() Assert.That(completed); search.Clear(); - await search.Add(weather); - await search.Add(raining); - await search.Add(sometext); + search.Add(weather); + search.Add(raining); + search.Add(sometext); - searchKey = await search.IncrementalSearch(weather); + searchKey = search.IncrementalSearch(weather); (results, distances, completed) = search.IncrementalFetch(searchKey, 1); Assert.That(searchKey == 0); Assert.That(results.Length == 1); @@ -269,7 +259,7 @@ public async Task TestIncrementalSearch() Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); Assert.That(completed); - searchKey = await search.IncrementalSearch(weather); + searchKey = search.IncrementalSearch(weather); (results, distances, completed) = search.IncrementalFetch(searchKey, 2); Assert.That(searchKey == 1); Assert.That(results.Length == 2); @@ -283,12 +273,12 @@ public async Task TestIncrementalSearch() search.IncrementalSearchComplete(searchKey); search.Clear(); - await search.Add(weather, "0"); - await search.Add(raining, "1"); - await search.Add(sometext, "0"); - await search.Add(sometext, "1"); + search.Add(weather, "0"); + search.Add(raining, "1"); + search.Add(sometext, "0"); + search.Add(sometext, "1"); - searchKey = await search.IncrementalSearch(weather, "0"); + searchKey = search.IncrementalSearch(weather, "0"); (results, distances, completed) = search.IncrementalFetch(searchKey, 2); Assert.That(searchKey == 0); Assert.AreEqual(results.Length, 2); @@ -299,7 +289,7 @@ public async Task TestIncrementalSearch() Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); Assert.That(completed); - searchKey = await search.IncrementalSearch(weather, "0"); + searchKey = search.IncrementalSearch(weather, "0"); (results, distances, completed) = search.IncrementalFetch(searchKey, 2); Assert.That(searchKey == 1); Assert.AreEqual(results.Length, 2); @@ -310,7 +300,7 @@ public async Task TestIncrementalSearch() Assert.That(ApproxEqual(distances[1], weatherSometextDiff)); Assert.That(completed); - searchKey = await search.IncrementalSearch(weather, "1"); + searchKey = search.IncrementalSearch(weather, "1"); (results, distances, completed) = search.IncrementalFetch(searchKey, 1); Assert.That(searchKey == 2); Assert.AreEqual(results.Length, 1); @@ -325,7 +315,7 @@ public async Task TestIncrementalSearch() Assert.That(ApproxEqual(distances[0], weatherSometextDiff)); Assert.That(completed); - searchKey = await search.IncrementalSearch(weather, "1"); + searchKey = search.IncrementalSearch(weather, "1"); (results, distances, completed) = search.IncrementalFetch(searchKey, 3); Assert.That(searchKey == 3); Assert.AreEqual(results.Length, 2); @@ -337,19 +327,19 @@ public async Task TestIncrementalSearch() search.Clear(); } - public virtual async Task TestSaveLoad() + public virtual void TestSaveLoad() { string path = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); string[] results; float[] distances; - await search.Add(weather); - await search.Add(raining); - await search.Add(sometext); + search.Add(weather); + search.Add(raining); + search.Add(sometext); search.Save(path); search.Clear(); - await search.Load(path); + search.Load(path); File.Delete(path); Assert.That(search.Count() == 3); @@ -357,7 +347,7 @@ public virtual async Task TestSaveLoad() Assert.That(search.Get(1) == raining); Assert.That(search.Get(2) == sometext); - (results, distances) = await search.Search(raining, 2); + (results, distances) = search.Search(raining, 2); Assert.AreEqual(results[0], raining); Assert.AreEqual(results[1], weather); Assert.That(ApproxEqual(distances[0], 0)); @@ -365,14 +355,14 @@ public virtual async Task TestSaveLoad() search.Clear(); - await search.Add(weather, "0"); - await search.Add(raining, "1"); - await search.Add(sometext, "0"); - await search.Add(sometext, "1"); + search.Add(weather, "0"); + search.Add(raining, "1"); + search.Add(sometext, "0"); + search.Add(sometext, "1"); search.Save(path); search.Clear(); - await search.Load(path); + search.Load(path); File.Delete(path); Assert.That(search.Count() == 4); @@ -383,7 +373,7 @@ public virtual async Task TestSaveLoad() Assert.That(search.Get(2) == sometext); Assert.That(search.Get(3) == sometext); - (results, distances) = await search.Search(raining, 2, "0"); + (results, distances) = search.Search(raining, 2, "0"); Assert.AreEqual(results[0], weather); Assert.AreEqual(results[1], sometext); Assert.That(ApproxEqual(distances[0], weatherRainingDiff)); @@ -401,34 +391,34 @@ public override T CreateSearch() return search; } - public override async Task Tests() + public override void Tests() { - await base.Tests(); - await TestEncode(); - await TestSimilarity(); - await TestSearchFromList(); + base.Tests(); + TestEncode(); + TestSimilarity(); + TestSearchFromList(); } - public async Task TestEncode() + public void TestEncode() { - float[] encoding = await search.Encode(weather); + float[] encoding = search.Encode(weather); Assert.That(ApproxEqual(encoding[0], -0.02910374f)); Assert.That(ApproxEqual(encoding[383], 0.01764517f)); } - public async Task TestSimilarity() + public void TestSimilarity() { - float[] sentence1 = await search.Encode(weather); - float[] sentence2 = await search.Encode(raining); + float[] sentence1 = search.Encode(weather); + float[] sentence2 = search.Encode(raining); float similarity = SimpleSearch.DotProduct(sentence1, sentence2); float distance = SimpleSearch.InverseDotProduct(sentence1, sentence2); Assert.That(ApproxEqual(similarity, 1 - weatherRainingDiff)); Assert.That(ApproxEqual(distance, weatherRainingDiff)); } - public async Task TestSearchFromList() + public void TestSearchFromList() { - (string[] results, float[] distances) = await search.SearchFromList(weather, new string[] {sometext, raining}); + (string[] results, float[] distances) = search.SearchFromList(weather, new string[] {sometext, raining}); Assert.AreEqual(results.Length, 2); Assert.AreEqual(distances.Length, 2); Assert.AreEqual(results[0], raining); @@ -491,18 +481,18 @@ public static (string, List<(int, int)>) GenerateText(int length) return (new string(generatedText), indices); } - public override async Task Tests() + public override void Tests() { - await base.Tests(); - await TestProperSplit(); + base.Tests(); + TestProperSplit(); } - public async Task TestProperSplit() + public void TestProperSplit() { for (int length = 50; length <= 500; length += 50) { (string randomText, _) = GenerateText(length); - List<(int, int)> indices = await search.Split(randomText); + List<(int, int)> indices = search.Split(randomText); int currIndex = 0; foreach ((int startIndex, int endIndex) in indices) { @@ -510,7 +500,7 @@ public async Task TestProperSplit() currIndex = endIndex + 1; } Assert.AreEqual(currIndex, length); - int key = await search.Add(randomText); + int key = search.Add(randomText); Assert.AreEqual(search.Get(key), randomText); } } @@ -520,13 +510,13 @@ public class TestTokenSplitter : TestSplitter {} public class TestWordSplitter : TestSplitter { - public override async Task Tests() + public override void Tests() { - await base.Tests(); - await TestSplit(); + base.Tests(); + TestSplit(); } - public async Task TestSplit() + public void TestSplit() { System.Random random = new System.Random(); char[] characters = "abcdefghijklmnopqrstuvwxyz".ToCharArray(); @@ -549,7 +539,7 @@ public async Task TestSplit() } string text = String.Join(" ", splits); - List<(int, int)> indices = await search.Split(text); + List<(int, int)> indices = search.Split(text); for (int i = 0; i < indices.Count; i++) { (int startIndex, int endIndex) = indices[i]; @@ -563,17 +553,17 @@ public async Task TestSplit() public class TestSentenceSplitter : TestSplitter { - public override async Task Tests() + public override void Tests() { - await base.Tests(); - await TestSplit(); + base.Tests(); + TestSplit(); } - public async Task TestSplit() + public void TestSplit() { - async Task SplitSentences(string text) + string[] SplitSentences(string text) { - List<(int, int)> indices = await search.Split(text); + List<(int, int)> indices = search.Split(text); List sentences = new List(); foreach ((int startIndex, int endIndex) in indices) sentences.Add(text.Substring(startIndex, endIndex - startIndex + 1)); return sentences.ToArray(); @@ -592,9 +582,9 @@ async Task SplitSentences(string text) sentencesGT = (string[])sentences.Clone(); text = String.Join("", sentencesGT); - sentencesBack = await SplitSentences(text); + sentencesBack = SplitSentences(text); Assert.AreEqual(sentencesBack, sentencesGT); - key = await search.Add(text); + key = search.Add(text); Assert.AreEqual(search.Get(key), text); sentencesGT = (string[])sentences.Clone(); @@ -603,18 +593,18 @@ async Task SplitSentences(string text) sentencesGT[2] += ".... "; sentencesGT[3] += " ?"; text = String.Join("", sentencesGT); - sentencesBack = await SplitSentences(text); + sentencesBack = SplitSentences(text); Assert.AreEqual(sentencesBack, sentencesGT); - key = await search.Add(text); + key = search.Add(text); Assert.AreEqual(search.Get(key), text); for (int length = 10; length <= 100; length += 10) { (string randomText, List<(int, int)> indicesGT) = GenerateText(length); - List<(int, int)> indices = await search.Split(randomText); + List<(int, int)> indices = search.Split(randomText); Assert.AreEqual(indices.Count, indicesGT.Count); Assert.AreEqual(indices, indicesGT); - key = await search.Add(randomText); + key = search.Add(randomText); Assert.AreEqual(search.Get(key), randomText); } @@ -661,24 +651,24 @@ public class TestRAG_DBSearch_TokenSplitter : TestRAG public abstract class TestRAG_Chunking : TestRAG { - public override async Task TestSearch() + public override void TestSearch() { - await base.TestSearch(); + base.TestSearch(); string[] results; float[] distances; - await search.Add(weather + raining); - await search.Add(sometext); + search.Add(weather + raining); + search.Add(sometext); search.ReturnChunks(false); - (results, distances) = await search.Search(weather, 1); + (results, distances) = search.Search(weather, 1); Assert.That(results.Length == 1); Assert.That(distances.Length == 1); Assert.AreEqual(results[0], weather + raining); search.ReturnChunks(true); - (results, distances) = await search.Search(weather, 1); + (results, distances) = search.Search(weather, 1); Assert.That(results.Length == 1); Assert.That(distances.Length == 1); Assert.AreEqual(results[0], weather); diff --git a/VERSION b/VERSION index f6dcb643..ad55eb85 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.5.2 +v3.0.0 diff --git a/package.json b/package.json index 7095e04a..0dea1d41 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ai.undream.llm", - "version": "2.5.2", + "version": "3.0.0", "displayName": "LLM for Unity", "description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.", "unity": "2022.3", @@ -61,6 +61,9 @@ "path": "Samples~/KnowledgeBaseGame" } ], + "dependencies": { + "com.unity.nuget.newtonsoft-json": "3.0.2" + }, "author": { "name": "Undream AI", "email": "hello@undream.ai",