Merge pull request #28 from undreamai/feature/awake_instead_of_onenable
Start server on Awake instead of OnEnable
amakropoulos authored Jan 16, 2024
2 parents bc72492 + 9919343 commit 3c3072e
Showing 5 changed files with 53 additions and 12 deletions.
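With this change the LLM server is launched from Awake, and the commit also documents a dynamic-setup pattern in the README: keep the GameObject inactive while adding and configuring the component, then activate it so Awake fires and the server starts. A minimal sketch of that pattern, mirroring the README example added below (the class name here is just for illustration):

``` c#
using UnityEngine;
using LLMUnity;

// Sketch mirroring the README example added in this commit: on an inactive
// GameObject, AddComponent does not trigger Awake, so the server is not started
// until the object is re-activated after configuration.
public class DynamicLLMSetup : MonoBehaviour
{
    async void Start()
    {
        gameObject.SetActive(false);                      // defer Awake of the new component
        LLM llm = gameObject.AddComponent<LLM>();
        await llm.SetModel("mistral-7b-instruct-v0.1.Q4_K_M.gguf");
        llm.prompt = "A chat between a curious human and an artificial intelligence assistant.";
        gameObject.SetActive(true);                       // Awake runs here and starts the server
    }
}
```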
README.md: 32 changes (32 additions, 0 deletions)
@@ -52,6 +52,8 @@ Create a GameObject for the LLM :chess_pawn::

In your script you can then use it as follows :unicorn::
``` c#
using LLMUnity;

public class MyScript {
public LLM llm;

@@ -125,6 +127,36 @@ You can also:
}
```

</details>
<details>
<summary>Add a LLM / LLMClient component dynamically</summary>

``` c#
using UnityEngine;
using LLMUnity;

public class MyScript : MonoBehaviour
{
LLM llm;
LLMClient llmclient;

async void Start()
{
// Add and setup a LLM object
gameObject.SetActive(false);
llm = gameObject.AddComponent<LLM>();
await llm.SetModel("mistral-7b-instruct-v0.1.Q4_K_M.gguf");
llm.prompt = "A chat between a curious human and an artificial intelligence assistant.";
gameObject.SetActive(true);
// or a LLMClient object
gameObject.SetActive(false);
llmclient = gameObject.AddComponent<LLMClient>();
llmclient.prompt = "A chat between a curious human and an artificial intelligence assistant.";
gameObject.SetActive(true);
}
}
```

</details>


Runtime/LLM.cs: 10 changes (5 additions, 5 deletions)
@@ -48,7 +48,7 @@ private static string GetAssetPath(string relPath = "")

#if UNITY_EDITOR
[InitializeOnLoadMethod]
private static async void InitializeOnLoad()
private static async Task InitializeOnLoad()
{
// Perform download when the build is finished
await DownloadBinaries();
@@ -86,7 +86,7 @@ public void SetModelProgress(float progress)
modelProgress = progress;
}

public async void SetModel(string path)
public async Task SetModel(string path)
{
// set the model and enable the model editor properties
modelCopyProgress = 0;
@@ -95,7 +95,7 @@ public async void SetModel(string path)
modelCopyProgress = 1;
}

public async void SetLora(string path)
public async Task SetLora(string path)
{
// set the lora and enable the model editor properties
modelCopyProgress = 0;
@@ -106,11 +106,11 @@ public async void SetLora(string path)

#endif

new public void OnEnable()
new public void Awake()
{
// start the llm server and run the OnEnable of the client
StartLLMServer();
base.OnEnable();
base.Awake();
}

private string SelectApeBinary()
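SetModel and SetLora now return Task instead of async void, so callers can await them and know the model or lora copy has finished before the component is used. A minimal sketch of awaiting both from an editor-side helper; the helper class and the lora file name are placeholders, and these setters sit inside the UNITY_EDITOR block:

``` c#
#if UNITY_EDITOR
using System.Threading.Tasks;
using LLMUnity;

// Sketch: with Task-returning setters, configuration can be awaited so the
// model/lora copy completes before the object is activated. File names are placeholders.
public static class LLMModelSetup
{
    public static async Task ConfigureAsync(LLM llm)
    {
        await llm.SetModel("mistral-7b-instruct-v0.1.Q4_K_M.gguf");
        await llm.SetLora("my-lora-adapter.gguf");  // placeholder lora file
    }
}
#endif
```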
Runtime/LLMClient.cs: 2 changes (1 addition, 1 deletion)
@@ -47,7 +47,7 @@ public LLMClient()
chat.Add(new ChatMessage { role = "system", content = prompt });
}

public async void OnEnable()
public async void Awake()
{
// initialise the prompt and set the keep tokens based on its length
currentPrompt = prompt;
Runtime/LLMUnitySetup.cs: 9 changes (8 additions, 1 deletion)
@@ -6,11 +6,13 @@
using Debug = UnityEngine.Debug;
using System.Threading.Tasks;
using System.Collections.Generic;
using System;

namespace LLMUnity
{
public delegate void EmptyCallback();
public delegate void Callback<T>(T message);
public delegate Task TaskCallback<T>(T message);
public delegate T2 ContentCallback<T, T2>(T message);

public class LLMUnitySetup : MonoBehaviour
@@ -60,7 +62,7 @@ public static string RunProcess(string command, string commandArgs = "", Callbac
#if UNITY_EDITOR
public static async Task DownloadFile(
string fileUrl, string savePath, bool executable = false,
Callback<string> callback = null, Callback<float> progresscallback = null,
TaskCallback<string> callback = null, Callback<float> progresscallback = null,
int chunkSize = 1024 * 1024)
{
// download a file to the specified path
@@ -129,6 +131,11 @@ public static async Task DownloadFile(

public static async Task<string> AddAsset(string assetPath, string basePath)
{
if (!File.Exists(assetPath))
{
Debug.LogError($"{assetPath} does not exist!");
return null;
}
// add an asset to the basePath directory if it is not already there and return the relative path
Directory.CreateDirectory(basePath);
string fullPath = Path.GetFullPath(assetPath);
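The new TaskCallback&lt;T&gt; delegate (delegate Task TaskCallback&lt;T&gt;(T message)) lets DownloadFile accept an asynchronous callback instead of a fire-and-forget one. A sketch of how it might be called; the URL, save path, and helper class are placeholders, and DownloadFile is editor-only:

``` c#
#if UNITY_EDITOR
using System.Threading.Tasks;
using UnityEngine;
using LLMUnity;

// Sketch: passing an async callback to DownloadFile via the new TaskCallback<string>
// delegate. The URL and save path below are placeholders, not real package assets.
public static class DownloadExample
{
    public static async Task Run()
    {
        await LLMUnitySetup.DownloadFile(
            "https://example.com/some-model.gguf",       // placeholder URL
            "Assets/StreamingAssets/some-model.gguf",    // placeholder save path
            executable: false,
            callback: async (message) =>                 // async post-download step
            {
                Debug.Log($"Download callback: {message}");
                await Task.Yield();
            },
            progresscallback: progress => Debug.Log($"Download progress: {progress:P0}"));
    }
}
#endif
```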
Runtime/undream.llmunity.Runtime.api: 12 changes (7 additions, 5 deletions)
@@ -112,12 +112,12 @@ namespace LLMUnity
[ServerAdvanced] public int parallelPrompts;
public System.Threading.ManualResetEvent serverStarted;
public LLM() {}
public void Awake();
public void DownloadModel();
public void OnDestroy();
public void OnEnable();
public static void SetBinariesProgress(float progress);
public void SetLora(string path);
public void SetModel(string path);
public System.Threading.Tasks.Task SetLora(string path);
public System.Threading.Tasks.Task SetModel(string path);
public void SetModelProgress(float progress);
public void StopProcess();
}
@@ -137,14 +137,14 @@ namespace LLMUnity
[ModelAdvanced] public int topK = 40;
[ModelAdvanced] public float topP = 0.9f;
public LLMClient() {}
public void Awake();
public System.Threading.Tasks.Task Chat(string question, LLMUnity.Callback<string> callback = default(LLMUnity.Callback<string>), EmptyCallback completionCallback = default(EmptyCallback), bool addToHistory = true);
public string ChatContent(ChatResult result);
public string ChatOpenAIContent(ChatOpenAIResult result);
public Ret ConvertContent<Res, Ret>(string response, LLMUnity.ContentCallback<Res, Ret> getContent = default(LLMUnity.ContentCallback<Res, Ret>));
public ChatRequest GenerateRequest(string message, bool openAIFormat = false);
public string MultiChatContent(MultiChatResult result);
public string[] MultiResponse(string response);
public void OnEnable();
public System.Threading.Tasks.Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, LLMUnity.ContentCallback<Res, Ret> getContent, LLMUnity.Callback<Ret> callback = default(LLMUnity.Callback<Ret>));
public System.Threading.Tasks.Task Tokenize(string question, LLMUnity.Callback<System.Collections.Generic.List<int>> callback = default(LLMUnity.Callback<System.Collections.Generic.List<int>>));
public System.Collections.Generic.List<int> TokenizeContent(TokenizeResult result);
@@ -156,7 +156,7 @@ namespace LLMUnity
public LLMUnitySetup() {}
public static System.Threading.Tasks.Task<string> AddAsset(string assetPath, string basePath);
public static System.Diagnostics.Process CreateProcess(string command, string commandArgs = @"", LLMUnity.Callback<string> outputCallback = default(LLMUnity.Callback<string>), LLMUnity.Callback<string> errorCallback = default(LLMUnity.Callback<string>), System.Collections.Generic.List<System.ValueTuple<string, string>> environment = default(System.Collections.Generic.List<System.ValueTuple<string, string>>), bool redirectOutput = false, bool redirectError = false);
public static System.Threading.Tasks.Task DownloadFile(string fileUrl, string savePath, bool executable = false, LLMUnity.Callback<string> callback = default(LLMUnity.Callback<string>), LLMUnity.Callback<float> progresscallback = default(LLMUnity.Callback<float>), int chunkSize = 1048576);
public static System.Threading.Tasks.Task DownloadFile(string fileUrl, string savePath, bool executable = false, LLMUnity.TaskCallback<string> callback = default(LLMUnity.TaskCallback<string>), LLMUnity.Callback<float> progresscallback = default(LLMUnity.Callback<float>), int chunkSize = 1048576);
public static string RunProcess(string command, string commandArgs = @"", LLMUnity.Callback<string> outputCallback = default(LLMUnity.Callback<string>), LLMUnity.Callback<string> errorCallback = default(LLMUnity.Callback<string>));
}

@@ -196,6 +196,8 @@ namespace LLMUnity
public System.DateTime timestamp;
}

public delegate System.Threading.Tasks.Task TaskCallback<T>(T message);

public struct TokenizeRequest
{
public string content;
