Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start server on Awake instead of OnEnable #28

Merged
merged 3 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ Create a GameObject for the LLM :chess_pawn::

In your script you can then use it as follows :unicorn::
``` c#
using LLMUnity;

public class MyScript {
public LLM llm;

Expand Down Expand Up @@ -125,6 +127,36 @@ You can also:
}
```

</details>
<details>
<summary>Add a LLM / LLMClient component dynamically</summary>

``` c#
using UnityEngine;
using LLMUnity;

public class MyScript : MonoBehaviour
{
LLM llm;
LLMClient llmclient;

async void Start()
{
// Add and setup a LLM object
gameObject.SetActive(false);
llm = gameObject.AddComponent<LLM>();
await llm.SetModel("mistral-7b-instruct-v0.1.Q4_K_M.gguf");
llm.prompt = "A chat between a curious human and an artificial intelligence assistant.";
gameObject.SetActive(true);
// or a LLMClient object
gameObject.SetActive(false);
llmclient = gameObject.AddComponent<LLMClient>();
llmclient.prompt = "A chat between a curious human and an artificial intelligence assistant.";
gameObject.SetActive(true);
}
}
```

</details>


Expand Down
10 changes: 5 additions & 5 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static string GetAssetPath(string relPath = "")

#if UNITY_EDITOR
[InitializeOnLoadMethod]
private static async void InitializeOnLoad()
private static async Task InitializeOnLoad()
{
// Perform download when the build is finished
await DownloadBinaries();
Expand Down Expand Up @@ -86,7 +86,7 @@ public void SetModelProgress(float progress)
modelProgress = progress;
}

public async void SetModel(string path)
public async Task SetModel(string path)
{
// set the model and enable the model editor properties
modelCopyProgress = 0;
Expand All @@ -95,7 +95,7 @@ public async void SetModel(string path)
modelCopyProgress = 1;
}

public async void SetLora(string path)
public async Task SetLora(string path)
{
// set the lora and enable the model editor properties
modelCopyProgress = 0;
Expand All @@ -106,11 +106,11 @@ public async void SetLora(string path)

#endif

new public void OnEnable()
new public void Awake()
{
// start the llm server and run the OnEnable of the client
StartLLMServer();
base.OnEnable();
base.Awake();
}

private string SelectApeBinary()
Expand Down
2 changes: 1 addition & 1 deletion Runtime/LLMClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public LLMClient()
chat.Add(new ChatMessage {role = "system", content = prompt});
}

public async void OnEnable()
public async void Awake()
{
// initialise the prompt and set the keep tokens based on its length
currentPrompt = prompt;
Expand Down
9 changes: 8 additions & 1 deletion Runtime/LLMUnitySetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
using Debug = UnityEngine.Debug;
using System.Threading.Tasks;
using System.Collections.Generic;
using System;

namespace LLMUnity
{
public delegate void EmptyCallback();
public delegate void Callback<T>(T message);
public delegate Task TaskCallback<T>(T message);
public delegate T2 ContentCallback<T, T2>(T message);

public class LLMUnitySetup : MonoBehaviour
Expand Down Expand Up @@ -60,7 +62,7 @@ public static string RunProcess(string command, string commandArgs = "", Callbac
#if UNITY_EDITOR
public static async Task DownloadFile(
string fileUrl, string savePath, bool executable = false,
Callback<string> callback = null, Callback<float> progresscallback = null,
TaskCallback<string> callback = null, Callback<float> progresscallback = null,
int chunkSize = 1024 * 1024)
{
// download a file to the specified path
Expand Down Expand Up @@ -129,6 +131,11 @@ public static async Task DownloadFile(

public static async Task<string> AddAsset(string assetPath, string basePath)
{
if (!File.Exists(assetPath))
{
Debug.LogError($"{assetPath} does not exist!");
return null;
}
// add an asset to the basePath directory if it is not already there and return the relative path
Directory.CreateDirectory(basePath);
string fullPath = Path.GetFullPath(assetPath);
Expand Down
12 changes: 7 additions & 5 deletions Runtime/undream.llmunity.Runtime.api
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ namespace LLMUnity
[ServerAdvanced] public int parallelPrompts;
public System.Threading.ManualResetEvent serverStarted;
public LLM() {}
public void Awake();
public void DownloadModel();
public void OnDestroy();
public void OnEnable();
public static void SetBinariesProgress(float progress);
public void SetLora(string path);
public void SetModel(string path);
public System.Threading.Tasks.Task SetLora(string path);
public System.Threading.Tasks.Task SetModel(string path);
public void SetModelProgress(float progress);
public void StopProcess();
}
Expand All @@ -137,14 +137,14 @@ namespace LLMUnity
[ModelAdvanced] public int topK = 40;
[ModelAdvanced] public float topP = 0.9f;
public LLMClient() {}
public void Awake();
public System.Threading.Tasks.Task Chat(string question, LLMUnity.Callback<string> callback = default(LLMUnity.Callback<string>), EmptyCallback completionCallback = default(EmptyCallback), bool addToHistory = true);
public string ChatContent(ChatResult result);
public string ChatOpenAIContent(ChatOpenAIResult result);
public Ret ConvertContent<Res, Ret>(string response, LLMUnity.ContentCallback<Res, Ret> getContent = default(LLMUnity.ContentCallback<Res, Ret>));
public ChatRequest GenerateRequest(string message, bool openAIFormat = false);
public string MultiChatContent(MultiChatResult result);
public string[] MultiResponse(string response);
public void OnEnable();
public System.Threading.Tasks.Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, LLMUnity.ContentCallback<Res, Ret> getContent, LLMUnity.Callback<Ret> callback = default(LLMUnity.Callback<Ret>));
public System.Threading.Tasks.Task Tokenize(string question, LLMUnity.Callback<System.Collections.Generic.List<int>> callback = default(LLMUnity.Callback<System.Collections.Generic.List<int>>));
public System.Collections.Generic.List<int> TokenizeContent(TokenizeResult result);
Expand All @@ -156,7 +156,7 @@ namespace LLMUnity
public LLMUnitySetup() {}
public static System.Threading.Tasks.Task<string> AddAsset(string assetPath, string basePath);
public static System.Diagnostics.Process CreateProcess(string command, string commandArgs = @"", LLMUnity.Callback<string> outputCallback = default(LLMUnity.Callback<string>), LLMUnity.Callback<string> errorCallback = default(LLMUnity.Callback<string>), System.Collections.Generic.List<System.ValueTuple<string, string>> environment = default(System.Collections.Generic.List<System.ValueTuple<string, string>>), bool redirectOutput = false, bool redirectError = false);
public static System.Threading.Tasks.Task DownloadFile(string fileUrl, string savePath, bool executable = false, LLMUnity.Callback<string> callback = default(LLMUnity.Callback<string>), LLMUnity.Callback<float> progresscallback = default(LLMUnity.Callback<float>), int chunkSize = 1048576);
public static System.Threading.Tasks.Task DownloadFile(string fileUrl, string savePath, bool executable = false, LLMUnity.TaskCallback<string> callback = default(LLMUnity.TaskCallback<string>), LLMUnity.Callback<float> progresscallback = default(LLMUnity.Callback<float>), int chunkSize = 1048576);
public static string RunProcess(string command, string commandArgs = @"", LLMUnity.Callback<string> outputCallback = default(LLMUnity.Callback<string>), LLMUnity.Callback<string> errorCallback = default(LLMUnity.Callback<string>));
}

Expand Down Expand Up @@ -196,6 +196,8 @@ namespace LLMUnity
public System.DateTime timestamp;
}

public delegate System.Threading.Tasks.Task TaskCallback<T>(T message);

public struct TokenizeRequest
{
public string content;
Expand Down