From 86c4eb41d1da74eaf96716902fa55ed929fb7426 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sat, 18 Mar 2023 09:09:40 -0400 Subject: [PATCH 1/9] WIP HTTP server for streaming inferences --- Cargo.lock | 555 +++++++++++++++++++++++++++++++++++++ Cargo.toml | 5 +- llama-http/Cargo.toml | 19 ++ llama-http/src/cli_args.rs | 79 ++++++ llama-http/src/main.rs | 129 +++++++++ 5 files changed, 785 insertions(+), 2 deletions(-) create mode 100644 llama-http/Cargo.toml create mode 100644 llama-http/src/cli_args.rs create mode 100644 llama-http/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 1c70a007..72256f17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + [[package]] name = "bincode" version = "1.3.3" @@ -26,12 +32,24 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bumpalo" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" + [[package]] name = "bytemuck" version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + [[package]] name = "cc" version = "1.0.79" @@ -115,6 +133,114 @@ dependencies = [ "libc", ] +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spin", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" + +[[package]] +name = "futures-executor" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" + +[[package]] +name = "futures-macro" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" + +[[package]] +name = "futures-task" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" + +[[package]] +name = "futures-util" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -122,8 +248,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -133,6 +261,31 @@ dependencies = [ "cc", ] +[[package]] +name = "h2" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "heck" version = "0.4.1" @@ -154,12 +307,80 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" + [[package]] name = "humantime" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "hyper" +version = "0.14.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "indexmap" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +dependencies = [ + "autocfg", + "hashbrown", +] + [[package]] name = "io-lifetimes" version = "1.0.6" @@ -182,6 +403,21 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "js-sys" +version = "0.3.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "libc" version = "0.2.140" @@ -207,6 +443,23 @@ dependencies = [ "rand", ] +[[package]] +name = "llama-http" +version = "0.1.0" +dependencies = [ + "clap", + "flume", + "futures", + "hyper", + "llama-rs", + "num_cpus", + "once_cell", + "rand", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "llama-rs" version = "0.1.0" @@ -220,6 +473,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "lock_api" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.17" @@ -235,6 +498,27 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "mio" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "num_cpus" version = "1.15.0" @@ -257,12 +541,67 @@ version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "partial_sort" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7924d1d0ad836f665c9065e26d016c673ece3993f30d340068b16f282afc1156" +[[package]] +name = "pin-project" +version = "1.0.12" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -341,6 +680,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.7.1" @@ -372,6 +720,18 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "serde" version = "1.0.156" @@ -392,6 +752,60 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "spin" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d6e0250b93c8427a177b849d144a96d5acc57006149479403d7861ab721e34" +dependencies = [ + "lock_api", +] + [[package]] name = "strsim" version = "0.10.0" @@ -438,6 +852,83 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + 
"parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-util" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + [[package]] name = "unicode-ident" version = "1.0.8" @@ -450,12 +941,76 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "want" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +dependencies = [ + "log", + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.84" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 43938e39..79657579 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,10 +2,11 @@ members = [ "ggml-raw", "llama-rs", - "llama-cli" + "llama-cli", + "llama-http" ] resolver = "2" [workspace.dependencies] -rand = "0.8.5" \ No newline at end of file +rand = "0.8.5" diff --git a/llama-http/Cargo.toml b/llama-http/Cargo.toml new file mode 100644 index 00000000..62fcfdda --- /dev/null +++ b/llama-http/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "llama-http" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +hyper = {version="0.14.25", features=["full"]} +futures = "*" +tokio = {version = "*", features=["full"]} +clap = { version = "4.1.8", features = ["derive"] } +once_cell = "1.17.1" +num_cpus = "1.15.0" +flume = "0.10.14" +llama-rs = { path = "../llama-rs" } +rand = { workspace = true } +serde = "*" +serde_json = "*" diff --git a/llama-http/src/cli_args.rs b/llama-http/src/cli_args.rs new file mode 100644 index 00000000..24fb6750 --- /dev/null +++ b/llama-http/src/cli_args.rs @@ -0,0 +1,79 @@ +use clap::Parser; +use once_cell::sync::Lazy; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The port to listen on + #[arg(long, short = 'P', default_value_t = 8080)] + pub port: u16, + + /// Where to load the model path from + #[arg(long, short = 'm')] + pub model_path: String, + + /// The prompt to feed the generator. Prefixes all prompts served. + #[arg(long, short = 'p', default_value = None)] + pub prompt: Option, + + /// A file to read the prompt from. Takes precedence over `prompt` if set. Prefixes all prompts + /// served. + #[arg(long, short = 'f', default_value = None)] + pub prompt_file: Option, + + /// Sets the number of threads to use + #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] + pub num_threads: usize, + + /// Sets how many tokens to predict + #[arg(long, short = 'n')] + pub num_predict: Option, + + /// Sets the size of the context (in tokens). Allows feeding longer prompts. + /// Note that this affects memory. TODO: Unsure how large the limit is. + #[arg(long, default_value_t = 512)] + pub num_ctx_tokens: usize, + + /// How many tokens from the prompt at a time to feed the network. Does not + /// affect generation. + #[arg(long, default_value_t = 8)] + pub batch_size: usize, + + /// Size of the 'last N' buffer that is used for the `repeat_penalty` + /// option. In tokens. + #[arg(long, default_value_t = 64)] + pub repeat_last_n: usize, + + /// The penalty for repeating tokens. Higher values make the generation less + /// likely to get into a loop, but may harm results when repetitive outputs + /// are desired. + #[arg(long, default_value_t = 1.30)] + pub repeat_penalty: f32, + + /// Temperature + #[arg(long, default_value_t = 0.80)] + pub temp: f32, + + /// Top-K: The top K words by score are kept during sampling. + #[arg(long, default_value_t = 40)] + pub top_k: usize, + + /// Top-p: The cummulative probability after which no more words are kept + /// for sampling. + #[arg(long, default_value_t = 0.95)] + pub top_p: f32, + + /// Stores a cached prompt at the given path. 
The same prompt can then be + /// loaded from disk using --restore-prompt + #[arg(long, default_value = None)] + pub cache_prompt: Option, + + /// Restores a cached prompt at the given path, previously using + /// --cache-prompt + #[arg(long, default_value = None)] + pub restore_prompt: Option, +} + +/// CLI args are stored in a lazy static variable so they're accessible from +/// everywhere. Arguments are parsed on first access. +pub static CLI_ARGS: Lazy = Lazy::new(Args::parse); diff --git a/llama-http/src/main.rs b/llama-http/src/main.rs new file mode 100644 index 00000000..5d2cfc07 --- /dev/null +++ b/llama-http/src/main.rs @@ -0,0 +1,129 @@ +use std::convert::Infallible; +use hyper::{Body, Method, Request, Response, Server}; +use hyper::service::{make_service_fn, service_fn}; +use llama_rs::{InferenceParameters, InferenceSession}; +use std::net::SocketAddr; +use futures::{SinkExt, channel::mpsc}; +use flume::{Sender, unbounded}; + +use serde::Deserialize; + +use rand::thread_rng; +mod cli_args; + +use cli_args::CLI_ARGS; + +#[derive(Debug, Deserialize)] +struct PredictionRequest { + num_predict: Option, + prompt: String, +} + +async fn handle_request(req: Request) -> Result, hyper::Error> { + match (req.method(), req.uri().path()) { + (&Method::POST, "/stream") => { + // Parse POST request body as a PredictionRequest + let body = hyper::body::to_bytes(req.into_body()).await?; + let prediction_request = match serde_json::from_slice::(&body) { + Ok(prediction_request) => prediction_request, + Err(_) => { + // Return 400 bad request if the body could not be parsed + let response = Response::builder() + .status(400) + .body(Body::empty()) + .unwrap(); + return Ok(response); + } + }; + + // Create a channel for the stream + let (tx, rx) = unbounded(); + let response_stream = rx.into_stream(); + inference_with_prediction_request(prediction_request, tx); + + // Create a response with a streaming body + let body = Body::wrap_stream(response_stream); + // Create a response with a streaming body + let response = Response::builder() + .header("Content-Type", "text/plain") + .body(body) + .unwrap(); + Ok(response) + }, + _ => { + // Return 404 not found for any other request + let response = Response::builder() + .status(404) + .body(Body::empty()) + .unwrap(); + Ok(response) + } + } +} + + +fn inference_with_prediction_request(prediction_request: PredictionRequest, tx: Sender>) { + let args = &*CLI_ARGS; + + let inference_params = InferenceParameters { + n_threads: args.num_threads as i32, + n_batch: args.batch_size, + top_k: args.top_k, + top_p: args.top_p, + repeat_penalty: args.repeat_penalty, + temp: args.temp, + }; + + // TODO Preload prompt + + + // Load model + let (model, vocabulary) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |_progress| { + println!("Loading model..."); + }).expect("Could not load model"); + + let mut rng = thread_rng(); + + let mut session = model.start_session(args.repeat_last_n); + + // print prompt + println!("{}", prediction_request.prompt); + + session.inference_with_prompt::( + &model, + &vocabulary, + &inference_params, + &prediction_request.prompt, + prediction_request.num_predict, + &mut rng, + { + let tx = tx.clone(); + move |t| { + // Send the generated text to the channel + let text = t.to_string(); + println!("{}", text); + tx.send(Ok(text)).unwrap(); + + Ok(()) + } + }, + ).expect("Could not run inference"); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = &*CLI_ARGS; + + let addr = 
SocketAddr::from(([127, 0, 0, 1], args.port)); + let server = Server::bind(&addr).serve(make_service_fn(|_| async { + Ok::<_, hyper::Error>(service_fn(handle_request)) + })); + + println!("Listening on http://{}", addr); + + if let Err(e) = server.await { + eprintln!("server error: {}", e); + } + + Ok(()) +} From 2cd1cf27204d01c6ebfbb8282b70f3f949292279 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 19 Mar 2023 10:08:21 -0400 Subject: [PATCH 2/9] feat: HTTP server with streaming and full inference requests --- llama-http/src/inference.rs | 105 +++++++++++++++++++++++ llama-http/src/main.rs | 163 +++++++++++++++++++----------------- 2 files changed, 192 insertions(+), 76 deletions(-) create mode 100644 llama-http/src/inference.rs diff --git a/llama-http/src/inference.rs b/llama-http/src/inference.rs new file mode 100644 index 00000000..da9ab4ac --- /dev/null +++ b/llama-http/src/inference.rs @@ -0,0 +1,105 @@ +use llama_rs::{InferenceParameters, InferenceSnapshot}; +use rand::thread_rng; +use std::convert::Infallible; + +use crate::cli_args::CLI_ARGS; +use flume::{unbounded, Receiver, Sender}; + +#[derive(Debug)] +pub struct InferenceRequest { + /// The channel to send the tokens to. + pub tx_tokens: Sender>, + + pub num_predict: Option, + pub prompt: String, + pub n_batch: Option, + pub top_k: Option, + pub top_p: Option, + pub repeat_penalty: Option, + pub temp: Option, +} + +pub fn initialize_model_and_handle_inferences() -> Sender { + // Create a channel for InferenceRequests and spawn a thread to handle them + + let (tx, rx) = unbounded(); + + std::thread::spawn(move || { + let args = &*CLI_ARGS; + + // TODO Preload prompt + + // Load model + let (mut model, vocabulary) = + llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |_progress| { + println!("Loading model..."); + }) + .expect("Could not load model"); + + let mut session = if let Some(restore_path) = &args.restore_prompt { + let snapshot = InferenceSnapshot::load_from_disk(restore_path); + match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { + Ok(session) => { + println!("Restored cached memory from {restore_path}"); + session + } + Err(err) => { + eprintln!("Could not restore prompt. Error: {err}"); + std::process::exit(1); + } + } + } else { + model.start_session(args.repeat_last_n) + }; + + let mut rng = thread_rng(); + let rx: Receiver = rx.clone(); + loop { + if let Ok(inference_request) = rx.try_recv() { + let inference_params = InferenceParameters { + n_threads: args.num_threads as i32, + n_batch: inference_request.n_batch.unwrap_or(args.batch_size), + top_k: inference_request.top_k.unwrap_or(args.top_k), + top_p: inference_request.top_p.unwrap_or(args.top_p), + repeat_penalty: inference_request + .repeat_penalty + .unwrap_or(args.repeat_penalty), + temp: inference_request.temp.unwrap_or(args.temp), + }; + + // Run inference + session + .inference_with_prompt::( + &model, + &vocabulary, + &inference_params, + &inference_request.prompt, + inference_request.num_predict, + &mut rng, + { + let tx_tokens = inference_request.tx_tokens.clone(); + move |t| { + let text = t.to_string(); + match tx_tokens.send(Ok(text)) { + Ok(_) => { + println!("Sent token {} to receiver.", t); + } + Err(_) => { + // The receiver has been dropped. 
+ println!("Could not send token to receiver."); + } + } + + Ok(()) + } + }, + ) + .expect("Could not run inference"); + } + + std::thread::sleep(std::time::Duration::from_millis(5)); + } + }); + + tx +} diff --git a/llama-http/src/main.rs b/llama-http/src/main.rs index 5d2cfc07..a4b59e66 100644 --- a/llama-http/src/main.rs +++ b/llama-http/src/main.rs @@ -1,122 +1,133 @@ -use std::convert::Infallible; -use hyper::{Body, Method, Request, Response, Server}; +use flume::{unbounded, Sender}; use hyper::service::{make_service_fn, service_fn}; -use llama_rs::{InferenceParameters, InferenceSession}; +use hyper::{Body, Method, Request, Response, Server}; use std::net::SocketAddr; -use futures::{SinkExt, channel::mpsc}; -use flume::{Sender, unbounded}; use serde::Deserialize; -use rand::thread_rng; mod cli_args; +mod inference; use cli_args::CLI_ARGS; +#[derive(Clone)] +struct HttpContext { + tx_inference_request: Sender, +} + +/// The JSON POST request body for the /stream endpoint. #[derive(Debug, Deserialize)] -struct PredictionRequest { +struct InferenceHttpRequest { num_predict: Option, prompt: String, + n_batch: Option, + top_k: Option, + top_p: Option, + repeat_penalty: Option, + temp: Option, } -async fn handle_request(req: Request) -> Result, hyper::Error> { +impl InferenceHttpRequest { + /// This function is used to convert the HTTP request into an `inference::InferenceRequest`. + /// This is is passed to the inference thread via a stream in the HTTP request. + /// + /// We cannot use the same `InferenceRequest` struct for parsing the HTTP response and + /// requesting an inference because the + /// inference thread needs to be able to send tokens back to the HTTP thread + /// via a channel, and this cannot be serialized. + fn to_inference_request( + &self, + tx_tokens: Sender>, + ) -> inference::InferenceRequest { + inference::InferenceRequest { + tx_tokens, + num_predict: self.num_predict, + prompt: self.prompt.clone(), + n_batch: self.n_batch, + top_k: self.top_k, + top_p: self.top_p, + repeat_penalty: self.repeat_penalty, + temp: self.temp, + } + } +} + +async fn handle_request( + context: HttpContext, + req: Request, +) -> Result, hyper::Error> { match (req.method(), req.uri().path()) { (&Method::POST, "/stream") => { // Parse POST request body as a PredictionRequest let body = hyper::body::to_bytes(req.into_body()).await?; - let prediction_request = match serde_json::from_slice::(&body) { - Ok(prediction_request) => prediction_request, + let inference_http_request = match serde_json::from_slice::(&body) { + Ok(inference_http_request) => inference_http_request, Err(_) => { // Return 400 bad request if the body could not be parsed - let response = Response::builder() - .status(400) - .body(Body::empty()) - .unwrap(); + let response = Response::builder().status(400).body(Body::empty()).unwrap(); return Ok(response); } }; // Create a channel for the stream - let (tx, rx) = unbounded(); - let response_stream = rx.into_stream(); - inference_with_prediction_request(prediction_request, tx); + let (tx_tokens, rx_tokens) = unbounded(); + + // Send the prediction request to the inference thread + let inference_request = inference_http_request.to_inference_request(tx_tokens); + context.tx_inference_request.send(inference_request).expect( + "Could not send request to inference thread - did the inference thread die?", + ); + + // Create a response channel. 
+ let (mut tx_http, rx_http) = Body::channel(); + tokio::spawn(async move { + // Read tokens from the channel and send them to the response channel + while let Ok(token) = rx_tokens.recv() { + let token = token.unwrap(); + + // Add a newline to the token to get around Hyper's buffering. + let token = format!("{}\n", token); + + if let Err(error) = tx_http.send_data(token.into()).await { + eprintln!("Error sending data to client: {}", error); + break; + } + } + }); - // Create a response with a streaming body - let body = Body::wrap_stream(response_stream); // Create a response with a streaming body let response = Response::builder() .header("Content-Type", "text/plain") - .body(body) + .body(rx_http) .unwrap(); + Ok(response) - }, + } _ => { // Return 404 not found for any other request - let response = Response::builder() - .status(404) - .body(Body::empty()) - .unwrap(); + let response = Response::builder().status(404).body(Body::empty()).unwrap(); Ok(response) } } } - -fn inference_with_prediction_request(prediction_request: PredictionRequest, tx: Sender>) { - let args = &*CLI_ARGS; - - let inference_params = InferenceParameters { - n_threads: args.num_threads as i32, - n_batch: args.batch_size, - top_k: args.top_k, - top_p: args.top_p, - repeat_penalty: args.repeat_penalty, - temp: args.temp, - }; - - // TODO Preload prompt - - - // Load model - let (model, vocabulary) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |_progress| { - println!("Loading model..."); - }).expect("Could not load model"); - - let mut rng = thread_rng(); - - let mut session = model.start_session(args.repeat_last_n); - - // print prompt - println!("{}", prediction_request.prompt); - - session.inference_with_prompt::( - &model, - &vocabulary, - &inference_params, - &prediction_request.prompt, - prediction_request.num_predict, - &mut rng, - { - let tx = tx.clone(); - move |t| { - // Send the generated text to the channel - let text = t.to_string(); - println!("{}", text); - tx.send(Ok(text)).unwrap(); - - Ok(()) - } - }, - ).expect("Could not run inference"); -} - #[tokio::main] async fn main() -> Result<(), Box> { let args = &*CLI_ARGS; + let request_tx = inference::initialize_model_and_handle_inferences(); + let addr = SocketAddr::from(([127, 0, 0, 1], args.port)); - let server = Server::bind(&addr).serve(make_service_fn(|_| async { - Ok::<_, hyper::Error>(service_fn(handle_request)) + + // Make HttpContext available to all requests + let http_context = HttpContext { + tx_inference_request: request_tx, + }; + + let server = Server::bind(&addr).serve(make_service_fn(move |_| { + let http_context = http_context.clone(); + let service = service_fn(move |req| handle_request(http_context.clone(), req)); + async move { Ok::<_, hyper::Error>(service) } })); println!("Listening on http://{}", addr); From 1608a5696b143e05c49658d04bbcdac29a38ad63 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 19 Mar 2023 10:24:32 -0400 Subject: [PATCH 3/9] format --- llama-http/src/main.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llama-http/src/main.rs b/llama-http/src/main.rs index a4b59e66..03e96657 100644 --- a/llama-http/src/main.rs +++ b/llama-http/src/main.rs @@ -29,10 +29,10 @@ struct InferenceHttpRequest { impl InferenceHttpRequest { /// This function is used to convert the HTTP request into an `inference::InferenceRequest`. - /// This is is passed to the inference thread via a stream in the HTTP request. 
+ /// This is is passed to the inference thread via a stream in the HTTP request. /// /// We cannot use the same `InferenceRequest` struct for parsing the HTTP response and - /// requesting an inference because the + /// requesting an inference because the /// inference thread needs to be able to send tokens back to the HTTP thread /// via a channel, and this cannot be serialized. fn to_inference_request( @@ -60,7 +60,8 @@ async fn handle_request( (&Method::POST, "/stream") => { // Parse POST request body as a PredictionRequest let body = hyper::body::to_bytes(req.into_body()).await?; - let inference_http_request = match serde_json::from_slice::(&body) { + let inference_http_request = match serde_json::from_slice::(&body) + { Ok(inference_http_request) => inference_http_request, Err(_) => { // Return 400 bad request if the body could not be parsed From e65d8f6a188914aa3c0c950e51d011988a1e5f22 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 19 Mar 2023 10:26:47 -0400 Subject: [PATCH 4/9] docs: fix reference to old name in comment --- llama-http/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-http/src/main.rs b/llama-http/src/main.rs index 03e96657..4e07c1fe 100644 --- a/llama-http/src/main.rs +++ b/llama-http/src/main.rs @@ -58,7 +58,7 @@ async fn handle_request( ) -> Result, hyper::Error> { match (req.method(), req.uri().path()) { (&Method::POST, "/stream") => { - // Parse POST request body as a PredictionRequest + // Parse POST request body as an InferenceHttpRequest let body = hyper::body::to_bytes(req.into_body()).await?; let inference_http_request = match serde_json::from_slice::(&body) { From 5906227c39b095360669aa155797cbd1d4bea26d Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 26 Mar 2023 07:15:34 -0400 Subject: [PATCH 5/9] fix: update http cli args to be more accurate --- Cargo.lock | 534 ++++++++++++++++++++++++++++++++++++ llama-http/src/cli_args.rs | 26 +- llama-http/src/inference.rs | 24 +- 3 files changed, 562 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5fce7727..96739694 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + [[package]] name = "bincode" version = "1.3.3" @@ -26,12 +32,24 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bumpalo" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" + [[package]] name = "bytemuck" version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + [[package]] name = "cc" version = "1.0.79" @@ -174,6 +192,114 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + 
"nanorand", + "pin-project", + "spin", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" + +[[package]] +name = "futures-executor" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" + +[[package]] +name = "futures-macro" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" + +[[package]] +name = "futures-task" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" + +[[package]] +name = "futures-util" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -181,8 +307,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -192,6 +320,31 @@ dependencies = [ "cc", ] +[[package]] +name = "h2" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "heck" version = "0.4.1" @@ 
-213,12 +366,80 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" + [[package]] name = "humantime" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "hyper" +version = "0.14.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown", +] + [[package]] name = "io-lifetimes" version = "1.0.6" @@ -241,6 +462,21 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "js-sys" +version = "0.3.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -274,6 +510,23 @@ dependencies = [ "spinners", ] +[[package]] +name = "llama-http" +version = "0.1.0" +dependencies = [ + "clap", + "flume", + "futures", + "hyper", + "llama-rs", + "num_cpus", + "once_cell", + "rand", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "llama-rs" version = "0.1.0" @@ -287,6 +540,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "lock_api" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.17" @@ -308,6 +571,27 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "mio" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -351,12 +635,67 @@ version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "partial_sort" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7924d1d0ad836f665c9065e26d016c673ece3993f30d340068b16f282afc1156" +[[package]] +name = "pin-project" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -525,6 +864,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + [[package]] name = "scopeguard" version = "1.1.0" @@ -551,12 +896,60 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "spin" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d6e0250b93c8427a177b849d144a96d5acc57006149479403d7861ab721e34" +dependencies = [ + "lock_api", +] + [[package]] name = "spinners" version = "4.1.0" @@ -648,6 +1041,83 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-util" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + [[package]] name = "unicode-ident" version = "1.0.8" @@ -678,12 +1148,76 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "want" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +dependencies = [ + "log", + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.84" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" + [[package]] name = "winapi" version = "0.3.9" diff --git a/llama-http/src/cli_args.rs b/llama-http/src/cli_args.rs index 24fb6750..a3c44570 100644 --- a/llama-http/src/cli_args.rs +++ b/llama-http/src/cli_args.rs @@ -12,23 +12,15 @@ pub struct Args { #[arg(long, short = 'm')] pub model_path: String, - /// The prompt to feed the generator. Prefixes all prompts served. - #[arg(long, short = 'p', default_value = None)] - pub prompt: Option, - - /// A file to read the prompt from. Takes precedence over `prompt` if set. Prefixes all prompts - /// served. - #[arg(long, short = 'f', default_value = None)] - pub prompt_file: Option, + /// Use 16-bit floats for model memory key and value. Ignored when restoring + /// from the cache. + #[arg(long, default_value_t = false)] + pub float16: bool, /// Sets the number of threads to use #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] pub num_threads: usize, - /// Sets how many tokens to predict - #[arg(long, short = 'n')] - pub num_predict: Option, - /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. TODO: Unsure how large the limit is. #[arg(long, default_value_t = 512)] @@ -36,6 +28,7 @@ pub struct Args { /// How many tokens from the prompt at a time to feed the network. Does not /// affect generation. + /// This is the default value unless overridden by the request. #[arg(long, default_value_t = 8)] pub batch_size: usize, @@ -47,27 +40,26 @@ pub struct Args { /// The penalty for repeating tokens. Higher values make the generation less /// likely to get into a loop, but may harm results when repetitive outputs /// are desired. + /// This is the default value unless overridden by the request. #[arg(long, default_value_t = 1.30)] pub repeat_penalty: f32, /// Temperature + /// This is the default value unless overridden by the request. #[arg(long, default_value_t = 0.80)] pub temp: f32, /// Top-K: The top K words by score are kept during sampling. + /// This is the default value unless overridden by the request. #[arg(long, default_value_t = 40)] pub top_k: usize, /// Top-p: The cummulative probability after which no more words are kept /// for sampling. + /// This is the default value unless overridden by the request. #[arg(long, default_value_t = 0.95)] pub top_p: f32, - /// Stores a cached prompt at the given path. 
The same prompt can then be - /// loaded from disk using --restore-prompt - #[arg(long, default_value = None)] - pub cache_prompt: Option, - /// Restores a cached prompt at the given path, previously using /// --cache-prompt #[arg(long, default_value = None)] diff --git a/llama-http/src/inference.rs b/llama-http/src/inference.rs index da9ab4ac..f1782f81 100644 --- a/llama-http/src/inference.rs +++ b/llama-http/src/inference.rs @@ -1,4 +1,7 @@ -use llama_rs::{InferenceParameters, InferenceSnapshot}; +use llama_rs::{ + InferenceParameters, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, + TokenBias, +}; use rand::thread_rng; use std::convert::Infallible; @@ -27,8 +30,6 @@ pub fn initialize_model_and_handle_inferences() -> Sender { std::thread::spawn(move || { let args = &*CLI_ARGS; - // TODO Preload prompt - // Load model let (mut model, vocabulary) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |_progress| { @@ -44,12 +45,24 @@ pub fn initialize_model_and_handle_inferences() -> Sender { session } Err(err) => { - eprintln!("Could not restore prompt. Error: {err}"); + eprintln!("Could not restore from snapshot. Error: {err}"); std::process::exit(1); } } } else { - model.start_session(args.repeat_last_n) + let inference_session_params = { + let mem_typ = if args.float16 { + ModelKVMemoryType::Float16 + } else { + ModelKVMemoryType::Float32 + }; + InferenceSessionParameters { + memory_k_type: mem_typ, + memory_v_type: mem_typ, + last_n_size: args.repeat_last_n, + } + }; + model.start_session(inference_session_params) }; let mut rng = thread_rng(); @@ -65,6 +78,7 @@ pub fn initialize_model_and_handle_inferences() -> Sender { .repeat_penalty .unwrap_or(args.repeat_penalty), temp: inference_request.temp.unwrap_or(args.temp), + bias_tokens: TokenBias::default(), }; // Run inference From 7b52b83aa093fcc7fc85f04ea4abdba93927bda5 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 26 Mar 2023 07:28:42 -0400 Subject: [PATCH 6/9] feat: use log crate in http --- Cargo.lock | 1 + llama-http/Cargo.toml | 1 + llama-http/src/inference.rs | 63 ++++++++++++++++++++++++++++++++----- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96739694..cda3fbda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -519,6 +519,7 @@ dependencies = [ "futures", "hyper", "llama-rs", + "log", "num_cpus", "once_cell", "rand", diff --git a/llama-http/Cargo.toml b/llama-http/Cargo.toml index 62fcfdda..db7e605f 100644 --- a/llama-http/Cargo.toml +++ b/llama-http/Cargo.toml @@ -17,3 +17,4 @@ llama-rs = { path = "../llama-rs" } rand = { workspace = true } serde = "*" serde_json = "*" +log = "0.4" diff --git a/llama-http/src/inference.rs b/llama-http/src/inference.rs index f1782f81..478ed57f 100644 --- a/llama-http/src/inference.rs +++ b/llama-http/src/inference.rs @@ -1,6 +1,6 @@ use llama_rs::{ - InferenceParameters, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, - TokenBias, + InferenceParameters, InferenceSessionParameters, InferenceSnapshot, LoadProgress, + ModelKVMemoryType, TokenBias, }; use rand::thread_rng; use std::convert::Infallible; @@ -32,8 +32,55 @@ pub fn initialize_model_and_handle_inferences() -> Sender { // Load model let (mut model, vocabulary) = - llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |_progress| { - println!("Loading model..."); + llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| { + match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + 
log::debug!("Loaded HyperParams {hparams:#?}") + } + LoadProgress::BadToken { index } => { + log::info!("Warning: Bad token in vocab at index {index}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::MemorySize { bytes, n_mem } => log::info!( + "Memory size: {} MB {}", + bytes as f32 / 1024.0 / 1024.0, + n_mem + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => log::info!( + "Loading model part {}/{} from '{}'\n", + current_part, + total_parts, + file.to_string_lossy(), + ), + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. + } => { + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + } }) .expect("Could not load model"); @@ -41,11 +88,11 @@ pub fn initialize_model_and_handle_inferences() -> Sender { let snapshot = InferenceSnapshot::load_from_disk(restore_path); match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { Ok(session) => { - println!("Restored cached memory from {restore_path}"); + log::info!("Restored cached memory from {restore_path}"); session } Err(err) => { - eprintln!("Could not restore from snapshot. Error: {err}"); + log::error!("Could not restore from snapshot. Error: {err}"); std::process::exit(1); } } @@ -96,11 +143,11 @@ pub fn initialize_model_and_handle_inferences() -> Sender { let text = t.to_string(); match tx_tokens.send(Ok(text)) { Ok(_) => { - println!("Sent token {} to receiver.", t); + log::debug!("Sent token {} to receiver.", t); } Err(_) => { // The receiver has been dropped. 
- println!("Could not send token to receiver."); + log::warn!("Could not send token to receiver."); } } From 3a9a44da4b0319b77aec3f20a56cafa8122f764a Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 26 Mar 2023 07:28:59 -0400 Subject: [PATCH 7/9] chore: fix the major versions in cargo.toml --- llama-http/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-http/Cargo.toml b/llama-http/Cargo.toml index db7e605f..91038a3d 100644 --- a/llama-http/Cargo.toml +++ b/llama-http/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] hyper = {version="0.14.25", features=["full"]} -futures = "*" +futures = "0.3" tokio = {version = "*", features=["full"]} clap = { version = "4.1.8", features = ["derive"] } once_cell = "1.17.1" @@ -15,6 +15,6 @@ num_cpus = "1.15.0" flume = "0.10.14" llama-rs = { path = "../llama-rs" } rand = { workspace = true } -serde = "*" -serde_json = "*" +serde = "1.0" +serde_json = "1.0" log = "0.4" From 541b498c196f9604824e8b68c01b568f55391dba Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 26 Mar 2023 08:45:19 -0400 Subject: [PATCH 8/9] feat: add env_logger to http --- Cargo.lock | 1 + llama-http/Cargo.toml | 1 + llama-http/src/main.rs | 9 +++++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cda3fbda..1bb7b863 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -515,6 +515,7 @@ name = "llama-http" version = "0.1.0" dependencies = [ "clap", + "env_logger", "flume", "futures", "hyper", diff --git a/llama-http/Cargo.toml b/llama-http/Cargo.toml index 91038a3d..5585b2e5 100644 --- a/llama-http/Cargo.toml +++ b/llama-http/Cargo.toml @@ -18,3 +18,4 @@ rand = { workspace = true } serde = "1.0" serde_json = "1.0" log = "0.4" +env_logger = "0.10.0" diff --git a/llama-http/src/main.rs b/llama-http/src/main.rs index 4e07c1fe..46190cc0 100644 --- a/llama-http/src/main.rs +++ b/llama-http/src/main.rs @@ -114,6 +114,11 @@ async fn handle_request( #[tokio::main] async fn main() -> Result<(), Box> { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .parse_default_env() + .init(); + let args = &*CLI_ARGS; let request_tx = inference::initialize_model_and_handle_inferences(); @@ -131,10 +136,10 @@ async fn main() -> Result<(), Box> { async move { Ok::<_, hyper::Error>(service) } })); - println!("Listening on http://{}", addr); + log::info!("Listening on http://{}", addr); if let Err(e) = server.await { - eprintln!("server error: {}", e); + log::error!("server error: {}", e); } Ok(()) From 89401073463e747eb533c3da583deeece049976b Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Sun, 26 Mar 2023 08:45:51 -0400 Subject: [PATCH 9/9] refactor: move model and session instantiation into InferenceSessionManager --- llama-http/src/inference.rs | 149 +++++++++++++++++++++--------------- 1 file changed, 89 insertions(+), 60 deletions(-) diff --git a/llama-http/src/inference.rs b/llama-http/src/inference.rs index 478ed57f..945486de 100644 --- a/llama-http/src/inference.rs +++ b/llama-http/src/inference.rs @@ -1,6 +1,6 @@ use llama_rs::{ - InferenceParameters, InferenceSessionParameters, InferenceSnapshot, LoadProgress, - ModelKVMemoryType, TokenBias, + InferenceParameters, InferenceSession, InferenceSessionParameters, InferenceSnapshot, + LoadProgress, ModelKVMemoryType, TokenBias, }; use rand::thread_rng; use std::convert::Infallible; @@ -30,8 +30,80 @@ pub fn initialize_model_and_handle_inferences() -> Sender { std::thread::spawn(move || { let args = &*CLI_ARGS; + let mut 
inference_session_manager = InferenceSessionManager::new(); + + let rx: Receiver = rx.clone(); + loop { + if let Ok(inference_request) = rx.try_recv() { + let mut session = inference_session_manager.get_session(); + let inference_params = InferenceParameters { + n_threads: args.num_threads as i32, + n_batch: inference_request.n_batch.unwrap_or(args.batch_size), + top_k: inference_request.top_k.unwrap_or(args.top_k), + top_p: inference_request.top_p.unwrap_or(args.top_p), + repeat_penalty: inference_request + .repeat_penalty + .unwrap_or(args.repeat_penalty), + temp: inference_request.temp.unwrap_or(args.temp), + bias_tokens: TokenBias::default(), + }; + let mut rng = thread_rng(); + // Run inference + let model = &(inference_session_manager.model); + let vocabulary = &(inference_session_manager.vocabulary); + session + .inference_with_prompt::( + model, + vocabulary, + &inference_params, + &inference_request.prompt, + inference_request.num_predict, + &mut rng, + { + let tx_tokens = inference_request.tx_tokens.clone(); + move |t| { + let text = t.to_string(); + match tx_tokens.send(Ok(text)) { + Ok(_) => { + log::debug!("Sent token {} to receiver.", t); + } + Err(_) => { + // The receiver has been dropped. + log::warn!("Could not send token to receiver."); + } + } + + Ok(()) + } + }, + ) + .expect("Could not run inference"); + } + + std::thread::sleep(std::time::Duration::from_millis(5)); + } + }); + + tx +} + +/// `InferenceSessionManager` is a way to create new sessions for a model and vocabulary. +/// In the future, it can also manage how many sessions are created and manage creating sessions +/// between threads. +struct InferenceSessionManager { + model: llama_rs::Model, + vocabulary: llama_rs::Vocabulary, +} + +impl InferenceSessionManager { + fn new() -> Self { + // TODO It's not a great pattern to inject these arguments from CLI_ARGS. + // If we ever wanted to support this struct in multiple places, please move the `args` + // variable into properties of this struct. + let args = &*CLI_ARGS; + // Load model - let (mut model, vocabulary) = + let (model, vocabulary) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| { match progress { LoadProgress::HyperparametersLoaded(hparams) => { @@ -84,16 +156,24 @@ pub fn initialize_model_and_handle_inferences() -> Sender { }) .expect("Could not load model"); - let mut session = if let Some(restore_path) = &args.restore_prompt { + Self { model, vocabulary } + } + + fn get_session(&mut self) -> InferenceSession { + // TODO It's not a great pattern to inject these arguments from CLI_ARGS. + // If we ever wanted to support this struct in multiple places, please move the `args` + // variable into properties of this struct. + let args = &*CLI_ARGS; + + if let Some(restore_path) = &args.restore_prompt { let snapshot = InferenceSnapshot::load_from_disk(restore_path); - match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { + match snapshot.and_then(|snapshot| self.model.session_from_snapshot(snapshot)) { Ok(session) => { log::info!("Restored cached memory from {restore_path}"); session } Err(err) => { - log::error!("Could not restore from snapshot. Error: {err}"); - std::process::exit(1); + panic!("Could not restore from snapshot. 
Error: {err}"); } } } else { @@ -109,58 +189,7 @@ pub fn initialize_model_and_handle_inferences() -> Sender { last_n_size: args.repeat_last_n, } }; - model.start_session(inference_session_params) - }; - - let mut rng = thread_rng(); - let rx: Receiver = rx.clone(); - loop { - if let Ok(inference_request) = rx.try_recv() { - let inference_params = InferenceParameters { - n_threads: args.num_threads as i32, - n_batch: inference_request.n_batch.unwrap_or(args.batch_size), - top_k: inference_request.top_k.unwrap_or(args.top_k), - top_p: inference_request.top_p.unwrap_or(args.top_p), - repeat_penalty: inference_request - .repeat_penalty - .unwrap_or(args.repeat_penalty), - temp: inference_request.temp.unwrap_or(args.temp), - bias_tokens: TokenBias::default(), - }; - - // Run inference - session - .inference_with_prompt::( - &model, - &vocabulary, - &inference_params, - &inference_request.prompt, - inference_request.num_predict, - &mut rng, - { - let tx_tokens = inference_request.tx_tokens.clone(); - move |t| { - let text = t.to_string(); - match tx_tokens.send(Ok(text)) { - Ok(_) => { - log::debug!("Sent token {} to receiver.", t); - } - Err(_) => { - // The receiver has been dropped. - log::warn!("Could not send token to receiver."); - } - } - - Ok(()) - } - }, - ) - .expect("Could not run inference"); - } - - std::thread::sleep(std::time::Duration::from_millis(5)); + self.model.start_session(inference_session_params) } - }); - - tx + } }