fix: correct Decoding behavior in incremental manner #491

Merged: 6 commits, Sep 29, 2023

Changes from all commits
14 changes: 3 additions & 11 deletions Cargo.lock

Generated file; diff not rendered by default.

1 change: 0 additions & 1 deletion Cargo.toml
@@ -9,7 +9,6 @@ members = [
"crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/stop-words",
"crates/http-api-bindings",
]

1 change: 0 additions & 1 deletion crates/ctranslate2-bindings/Cargo.toml
@@ -11,7 +11,6 @@ tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true }
stop-words = { path = "../stop-words" }
futures.workspace = true
async-stream.workspace = true

41 changes: 22 additions & 19 deletions crates/ctranslate2-bindings/src/lib.rs
@@ -4,8 +4,10 @@ use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use stop_words::{StopWords, StopWordsCondition};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tabby_inference::{
decoding::{DecodingFactory, IncrementalDecoding},
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken;
@@ -70,28 +72,28 @@ pub struct CTranslate2EngineOptions {
}

pub struct InferenceContext {
sender: Sender<u32>,
stop_condition: StopWordsCondition,
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
}

impl InferenceContext {
fn new(
sender: Sender<u32>,
stop_condition: StopWordsCondition,
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
) -> Self {
InferenceContext {
sender,
stop_condition,
decoding,
cancel,
}
}
}

pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
stop_words: StopWords,
decoding_factory: DecodingFactory,
tokenizer: Arc<Tokenizer>,
}

@@ -108,7 +110,7 @@ impl CTranslate2Engine {

return Self {
engine,
stop_words: StopWords::default(),
decoding_factory: DecodingFactory::default(),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
};
}
@@ -133,12 +135,12 @@ impl TextGeneration for CTranslate2Engine {
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();

let stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let decoding = self
.decoding_factory
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);

let (sender, mut receiver) = channel::<u32>(8);
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
tokio::task::spawn(async move {
let context = Box::new(context);
engine.inference(
@@ -150,16 +152,15 @@
);
});

while let Some(next_token_id) = receiver.recv().await {
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
while let Some(text) = receiver.recv().await {
yield text;
}
};
Box::pin(s)
}
}

fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
if max_length < tokens.len() {
let start = tokens.len() - max_length;
&tokens[start..]
@@ -174,10 +175,12 @@ fn inference_callback(
token_id: u32,
_token: String,
) -> bool {
let _ = context.sender.blocking_send(token_id);
if context.cancel.is_cancelled() {
true
} else if let Some(new_text) = context.decoding.next_token(token_id) {
let _ = context.sender.blocking_send(new_text);
false
} else {
context.stop_condition.next_token(token_id)
true
}
}
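
Side note for readers of the diff above: `DecodingFactory` and `IncrementalDecoding` come from the `tabby-inference` crate and are exercised here only through `create_incremental_decoding` and `next_token`. The following is a hypothetical, minimal sketch of a decoder with that shape, assuming the generated ids are re-decoded on every step and stop words are matched against the decoded text; the actual `tabby_inference::decoding` module may well differ.

```rust
// Hypothetical sketch only; not the actual tabby_inference::decoding code.
use std::sync::Arc;
use tokenizers::tokenizer::Tokenizer;

#[derive(Default)]
pub struct DecodingFactory;

pub struct IncrementalDecoding {
    tokenizer: Arc<Tokenizer>,
    token_ids: Vec<u32>,   // generated token ids seen so far
    emitted: String,       // decoded text already yielded to the caller
    stop_words: Vec<String>,
}

impl DecodingFactory {
    pub fn create_incremental_decoding(
        &self,
        tokenizer: Arc<Tokenizer>,
        _input_token_ids: &[u32], // prompt ids; a real decoder might use them for context
        stop_words: &[&str],
    ) -> IncrementalDecoding {
        IncrementalDecoding {
            tokenizer,
            token_ids: Vec::new(),
            emitted: String::new(),
            stop_words: stop_words.iter().map(|s| s.to_string()).collect(),
        }
    }
}

impl IncrementalDecoding {
    /// Feed one generated token id. Returns the newly decoded text fragment,
    /// or None once a stop word shows up (the caller should then stop generating).
    pub fn next_token(&mut self, id: u32) -> Option<String> {
        self.token_ids.push(id);
        // Re-decode the whole generated sequence so multi-token characters and
        // whitespace merges come out right; a real implementation would cache offsets.
        let text = self.tokenizer.decode(&self.token_ids, true).ok()?;
        if self.stop_words.iter().any(|w| text.contains(w)) {
            return None;
        }
        // Yield only the suffix that has not been emitted yet; if the decode
        // currently ends mid-character, emit nothing and wait for more tokens.
        let new_text = text.get(self.emitted.len()..).unwrap_or("").to_string();
        self.emitted = text;
        Some(new_text)
    }
}
```

With a decoder of this shape, the reworked `inference_callback` above simply forwards each `Some(new_text)` over the channel and turns `None` (stop word reached) into a `true` return value, which tells the C++ side to stop generating.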
1 change: 0 additions & 1 deletion crates/llama-cpp-bindings/Cargo.toml
@@ -14,7 +14,6 @@ tokio = { workspace = true, features = ["rt"] }
tabby-inference = { path = "../tabby-inference" }
derive_builder = { workspace = true }
tokenizers = { workspace = true }
stop-words = { version = "0.1.0", path = "../stop-words" }
tokio-util = { workspace = true }
futures.workspace = true
async-stream.workspace = true
4 changes: 2 additions & 2 deletions crates/llama-cpp-bindings/include/engine.h
@@ -9,8 +9,8 @@ class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();

virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
virtual uint32_t step(uint32_t next_token_id) const = 0;
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
virtual uint32_t step() const = 0;
virtual void end() const = 0;

virtual uint32_t eos_token() const = 0;
11 changes: 5 additions & 6 deletions crates/llama-cpp-bindings/src/engine.cc
@@ -45,22 +45,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
ctx_(std::move(ctx)) {
}

uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
void start(rust::Slice<const uint32_t> input_token_ids) const override {
auto* ctx = ctx_.get();
llama_reset_timings(ctx);
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ false);
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());

for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
}
return sample();
}

uint32_t step(uint32_t next_token_id) const override {
const llama_token id = next_token_id;
uint32_t step() const override {
const llama_token id = sample();
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
return sample();
return id;
}

void end() const override {
60 changes: 31 additions & 29 deletions crates/llama-cpp-bindings/src/lib.rs
@@ -5,8 +5,7 @@ use async_trait::async_trait;
use derive_builder::Builder;
use ffi::create_engine;
use futures::{lock::Mutex, stream::BoxStream};
use stop_words::StopWords;
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer;

#[cxx::bridge(namespace = "llama")]
@@ -18,8 +17,8 @@ mod ffi {

fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;

fn start(&self, prompt: &str, max_input_length: usize) -> u32;
fn step(&self, next_token_id: u32) -> u32;
fn start(&self, input_token_ids: &[u32]);
fn step(&self) -> u32;
fn end(&self);

fn eos_token(&self) -> u32;
@@ -38,15 +37,15 @@ pub struct LlamaEngineOptions {
pub struct LlamaEngine {
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
stop_words: StopWords,
decoding_factory: DecodingFactory,
}

impl LlamaEngine {
pub fn create(options: LlamaEngineOptions) -> Self {
LlamaEngine {
engine: Mutex::new(create_engine(&options.model_path)),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
stop_words: StopWords::default(),
decoding_factory: DecodingFactory::default(),
}
}
}
@@ -63,35 +62,29 @@ impl TextGeneration for LlamaEngine {
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let prompt = prompt.to_owned();
let mut stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let encoding = self.tokenizer.encode(prompt, true).unwrap();

let s = stream! {
let engine = self.engine.lock().await;
let eos_token = engine.eos_token();

let mut next_token_id = engine.start(&prompt, options.max_input_length);
if next_token_id == eos_token {
yield "".to_owned();
} else {
let mut n_remains = options.max_decoding_length - 1;

while n_remains > 0 {
next_token_id = engine.step(next_token_id);
if next_token_id == eos_token {
break;
}

if stop_condition.next_token(next_token_id) {
break;
}

let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
yield text;
n_remains -= 1;
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.step();
if next_token_id == eos_token {
break;
}

if let Some(new_text) = decoding.next_token(next_token_id) {
yield new_text;
} else {
break;
}

n_remains -= 1;
}

engine.end();
@@ -100,3 +93,12 @@ impl TextGeneration for LlamaEngine {
Box::pin(s)
}
}

fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
if max_length < tokens.len() {
let start = tokens.len() - max_length;
&tokens[start..]
} else {
tokens
}
}
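
The `truncate_tokens` helper added here keeps the most recent `max_length` ids rather than the earliest ones, so an over-long prompt loses its oldest context first. A small illustrative test (values are made up, not part of the PR):

```rust
#[cfg(test)]
mod tests {
    use super::truncate_tokens;

    #[test]
    fn keeps_the_most_recent_tokens() {
        let ids = [10_u32, 11, 12, 13, 14];
        // Longer than the limit: only the last three ids survive.
        assert_eq!(truncate_tokens(&ids, 3), &[12_u32, 13, 14]);
        // Already within the limit: returned unchanged.
        assert_eq!(truncate_tokens(&ids, 8), &ids[..]);
    }
}
```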
11 changes: 0 additions & 11 deletions crates/stop-words/Cargo.toml

This file was deleted.
