Commit
feat(backend): rely on multi-consumer queue to schedule workers
mfuntowicz committed Nov 22, 2024
1 parent 84eead2 commit 5a85661
Showing 3 changed files with 71 additions and 44 deletions.
49 changes: 49 additions & 0 deletions Cargo.lock

Generated lockfile; diff not rendered by default.

1 change: 1 addition & 0 deletions backends/llamacpp/Cargo.toml
@@ -7,6 +7,7 @@ homepage.workspace = true

[dependencies]
async-trait = "0.1"
async-channel = "2.3"
clap = { version = "4.5.19", features = ["derive"] }
cxx = "1.0"
num_cpus = "1"
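
The new async-channel dependency provides the multi-producer, multi-consumer queue that the reworked backend relies on: every worker clones the same Receiver, and each queued request is delivered to exactly one of the competing consumers. A minimal, self-contained sketch of that behaviour (not part of the commit; the worker count and job payloads are made up):

// Minimal sketch of async-channel's MPMC semantics: receivers are cloneable
// and competing recv calls each take a distinct message.
use std::thread;

fn main() {
    let (backlog_tx, backlog_rx) = async_channel::unbounded::<u32>();

    // Two competing consumers, standing in for two llama.cpp worker threads.
    let workers: Vec<_> = (0..2)
        .map(|id| {
            let backlog_rx = backlog_rx.clone();
            thread::spawn(move || {
                // recv_blocking returns Err once all senders are dropped and
                // the queue is drained, which ends this worker.
                while let Ok(job) = backlog_rx.recv_blocking() {
                    println!("worker {id} picked up job {job}");
                }
            })
        })
        .collect();

    for job in 0..4 {
        backlog_tx.send_blocking(job).expect("queue closed");
    }
    drop(backlog_tx); // close the queue so the workers exit

    for handle in workers {
        handle.join().unwrap();
    }
}
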
65 changes: 21 additions & 44 deletions backends/llamacpp/src/backend.rs
@@ -2,6 +2,7 @@ use crate::ffi::{
create_worker_frontend, set_numactl_core_affinity, GenerationParams, LlamaCppWorkerFrontend,
SamplingParams,
};
use async_channel::{unbounded as mpmc_unbounded, Receiver as MpmcReceiver, Sender as MpmcSender};
use async_trait::async_trait;
use cxx::UniquePtr;
use log::warn;
@@ -19,7 +20,6 @@ use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -102,18 +102,6 @@ pub enum LlamaCppBackendError {
ModelInitializationFailed(PathBuf, String),
}

struct LlamaCppWorker {
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
}

impl LlamaCppWorker {
fn submit(&self, ctx: GenerationContext, sx: UnboundedSender<InferResult>) {
if let Err(err) = self.sender.send((ctx, sx)) {
// TODO: What do we do?
}
}
}

pub struct LlamaCppBackend {
scheduler_sender: UnboundedSender<(GenerationContext, UnboundedSender<InferResult>)>,
scheduler_handle: JoinHandle<()>,
@@ -141,29 +129,26 @@ impl LlamaCppBackend {
));
}

let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
// Allocate the multi-consumer queue to orchestrate all the workers
let (backlog_submitter, backlog_receiver) = mpmc_unbounded();

// Allocate all the workers
let streams = cores_allocation
.iter()
.map(
|affinity| match Self::allocate_worker(path, num_cores_per_instance as u32) {
Ok(worker) => {
let tokenizer = Arc::clone(&tokenizer);
let (sender, receiver) = channel();
let affinity = affinity.clone().collect::<Vec<_>>();
spawn(move || worker_loop(worker, affinity, tokenizer, receiver));

Ok(LlamaCppWorker { sender })
}
Err(e) => Err(e),
},
)
.collect::<Result<Vec<_>, _>>()?;
let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
cores_allocation.iter().for_each(|affinity| {
match Self::allocate_worker(path, num_cores_per_instance as u32) {
Ok(worker) => {
let tokenizer = Arc::clone(&tokenizer);
let affinity = affinity.clone().collect::<Vec<_>>();
let backlog_receiver = backlog_receiver.clone();
spawn(move || worker_loop(worker, affinity, tokenizer, backlog_receiver));
}
Err(e) => {}
}
});

// Start the scheduler loop
let (scheduler_sender, scheduler_receiver) = unbounded_channel();
let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, streams));
let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, backlog_submitter));
Ok(Self {
scheduler_sender,
scheduler_handle,
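
For context, the pre-existing get_cores_allocation helper decides how many workers are spawned and which cores each one is pinned to (the affinity set is applied later, inside the worker thread, via set_numactl_core_affinity). A hypothetical sketch of that kind of partitioning, assuming one contiguous core range per worker; the real helper may differ:

// Hypothetical core partitioning, not the repo's get_cores_allocation: split
// the visible CPUs into contiguous, non-overlapping ranges of
// cores_per_worker cores each; any remainder stays unused. Assumes
// cores_per_worker >= 1.
fn split_cores(cores_per_worker: usize) -> Vec<std::ops::Range<usize>> {
    let total = num_cpus::get();
    (0..total / cores_per_worker)
        .map(|worker| worker * cores_per_worker..(worker + 1) * cores_per_worker)
        .collect()
}

fn main() {
    for (worker, affinity) in split_cores(4).into_iter().enumerate() {
        // Each range would become one worker's affinity before it is spawned.
        println!("worker {worker} -> cores {affinity:?}");
    }
}
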
@@ -263,24 +248,16 @@ fn llama_generate_callback(

async fn scheduler_loop(
mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
mut workers: Vec<LlamaCppWorker>,
backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// Semaphore allows us to wait for a worker to become available
let permits = Semaphore::new(workers.len());

// Let's receive incoming requests
loop {
match queue.recv().await {
None => break,
Some((ctx, sender)) => {
let permit = permits.try_acquire();
if let Err(err) = permit {
let _ = sender.send(Err(InferError::Overloaded(err)));
if let Err(e) = backlog.send((ctx, sender)).await {
todo!("What do we do")
}

// We can unwrap because we wouldn't have a semaphore available otherwise
let worker = workers.pop().unwrap();
worker.submit(ctx, sender);
}
}
}
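
With the semaphore gone, the scheduler no longer tracks worker availability; it only forwards each request into the shared backlog and lets whichever worker is free pick it up. A stand-alone sketch of that forwarding pattern with the repo types replaced by plain strings (assumes tokio with the rt and macros features; the error branch is one possible way to fill in the remaining todo!, not what the commit does):

use tokio::sync::mpsc::unbounded_channel;

#[tokio::main]
async fn main() {
    // Front queue: single consumer (the scheduler task), mirroring the tokio mpsc above.
    let (front_tx, mut front_rx) = unbounded_channel::<String>();
    // Backlog: multi-consumer, shared by all workers.
    let (backlog_tx, backlog_rx) = async_channel::unbounded::<String>();

    let scheduler = tokio::spawn(async move {
        while let Some(request) = front_rx.recv().await {
            // send only fails once every backlog receiver is gone, i.e. all
            // workers have shut down; logging and stopping is one sane choice.
            if backlog_tx.send(request).await.is_err() {
                eprintln!("all workers gone, dropping remaining requests");
                break;
            }
        }
    });

    // One worker draining the backlog on a blocking thread.
    let worker = std::thread::spawn(move || {
        while let Ok(request) = backlog_rx.recv_blocking() {
            println!("handling {request}");
        }
    });

    front_tx.send("hello".to_string()).unwrap();
    drop(front_tx); // lets the scheduler task finish...
    scheduler.await.unwrap(); // ...which drops backlog_tx and stops the worker
    worker.join().unwrap();
}
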
@@ -290,7 +267,7 @@ fn worker_loop(
mut backend: UniquePtr<LlamaCppWorkerFrontend>,
affinity: Vec<usize>,
tokenizer: Arc<Tokenizer>,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
backlog: MpmcReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// This loop will mostly decode single token at every step, so no need to rely on parallelism
tokenizers::utils::parallelism::set_parallelism(false);
@@ -299,7 +276,7 @@
set_numactl_core_affinity(&affinity);

loop {
if let Ok((generation, stream)) = backlog.recv() {
if let Ok((generation, stream)) = backlog.recv_blocking() {
let start = Instant::now();
let generation_params = generation.generation_params; // copy
let sampling_params = generation.sampling_params; // copy
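
Results travel the opposite direction through the UnboundedSender<InferResult> that is paired with each GenerationContext: tokio's unbounded send is synchronous, so the blocking worker thread can stream tokens back to the async router without touching the runtime. A reduced sketch of that return path, with plain strings standing in for InferResult (again an illustration, not the commit's code):

use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

// What a worker does per request: produce tokens and push them back through
// the caller's unbounded sender. send here never blocks or awaits.
fn run_generation(stream: UnboundedSender<String>) {
    for step in 0..3 {
        if stream.send(format!("token-{step}")).is_err() {
            // The client went away; stop generating for this request.
            break;
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = unbounded_channel::<String>();

    // Stand-in for one queued request being handled by a blocking worker.
    let worker = std::thread::spawn(move || run_generation(tx));

    // The async side consumes tokens as they arrive.
    while let Some(token) = rx.recv().await {
        println!("got {token}");
    }
    worker.join().unwrap();
}
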
