Skip to content

Commit

Permalink
feat: add support for OpenAI completion endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Syst3m1cAn0maly committed Jul 9, 2024
1 parent aabc22e commit 17d7a04
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
10 changes: 10 additions & 0 deletions crates/http-api-bindings/src/completion/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod llama;
mod mistral;
mod openai;

use std::sync::Arc;

use llama::LlamaCppEngine;
use mistral::MistralFIMEngine;
use openai::OpenAICompletionEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::CompletionStream;

Expand All @@ -24,6 +26,14 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
);
Arc::new(engine)
}
"openai/completion" => {
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
&model.api_endpoint,
model.api_key.clone(),
);
Arc::new(engine)

Check warning on line 35 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L29-L35

Added lines #L29 - L35 were not covered by tests
}

unsupported_kind => panic!(
"Unsupported model kind for http completion: {}",
Expand Down
92 changes: 92 additions & 0 deletions crates/http-api-bindings/src/completion/openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

/// Completion engine backed by an OpenAI-compatible `/completions` HTTP endpoint.
pub struct OpenAICompletionEngine {
    // Shared HTTP client reused for every request.
    client: reqwest::Client,
    // Value sent as the `model` field of each completion request.
    model_name: String,
    // Full URL of the completions endpoint (configured base + "/completions").
    api_endpoint: String,
    // Optional bearer token; sent as `Authorization: Bearer <key>` when present.
    api_key: Option<String>,
}

impl OpenAICompletionEngine {
    /// Builds an engine for an OpenAI-compatible completion API.
    ///
    /// * `model_name` — required; the model identifier to request.
    /// * `api_endpoint` — base URL of the API; `/completions` is appended.
    /// * `api_key` — optional bearer token for the `Authorization` header.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message when `model_name` is `None`, since
    /// the engine cannot issue requests without a model.
    pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
        // A missing model name is a configuration error: fail fast with a
        // message pointing at the cause instead of a bare `unwrap` panic.
        let model_name =
            model_name.expect("model_name is required for an openai completion model");

        Self {
            client: reqwest::Client::new(),
            model_name,
            // Tolerate a trailing slash on the configured endpoint so the
            // resulting URL never contains "//completions".
            api_endpoint: format!("{}/completions", api_endpoint.trim_end_matches('/')),
            api_key,
        }
    }
}

/// JSON body of a POST to the `/completions` endpoint (OpenAI wire format).
#[derive(Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: i32,
    temperature: f32,
    // Always set to true by this engine; the response arrives as SSE chunks.
    stream: bool,
    presence_penalty: f32,
}

/// One server-sent-event chunk of a streamed completion response.
#[derive(Deserialize)]
struct CompletionResponseChunk {
    choices: Vec<CompletionResponseChoice>,
}

/// A single choice within a streamed response chunk.
#[derive(Deserialize)]
struct CompletionResponseChoice {
    // Text delta produced by the model in this chunk.
    text: String,
    // Non-`None` when the server signals this choice has finished streaming.
    finish_reason: Option<String>,
}

#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
    /// Streams completion text for `prompt`, yielding each text delta as it
    /// arrives from the server-sent-event response.
    ///
    /// The stream ends when the server reports a `finish_reason`, sends the
    /// `[DONE]` sentinel, closes the connection, or sends an unparsable chunk.
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        let request = CompletionRequest {
            model: self.model_name.clone(),
            prompt: prompt.to_owned(),
            max_tokens: options.max_decoding_tokens,
            temperature: options.sampling_temperature,
            stream: true,
            presence_penalty: options.presence_penalty,
        };

        let mut request = self.client.post(&self.api_endpoint).json(&request);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let s = stream! {
            let mut es = EventSource::new(request).expect("Failed to create event source");
            while let Some(event) = es.next().await {
                match event {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(message)) => {
                        // OpenAI-style streams terminate with a literal "[DONE]"
                        // sentinel, which is not valid JSON. Stop cleanly instead
                        // of panicking inside serde_json.
                        if message.data.trim() == "[DONE]" {
                            break;
                        }
                        // A malformed chunk should end the stream gracefully
                        // rather than panic the consuming task.
                        let chunk: CompletionResponseChunk = match serde_json::from_str(&message.data) {
                            Ok(chunk) => chunk,
                            Err(_) => break,
                        };
                        if let Some(choice) = chunk.choices.first() {
                            yield choice.text.clone();

                            if choice.finish_reason.is_some() {
                                break;
                            }
                        }
                    }
                    Err(_) => {
                        // reqwest_eventsource yields an error (e.g. StreamEnd)
                        // when the connection closes; treat any error as
                        // end-of-stream.
                        break;
                    }
                }
            }
        };

        Box::pin(s)
    }
}

0 comments on commit 17d7a04

Please sign in to comment.