Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tool call support #16

Merged
merged 1 commit into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/config/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ use std::env;
/// Size of the stream buffer in bytes, taken from the environment.
///
/// Reads `STREAM_BUFFER_SIZE_BYTES` and parses it as a `usize`. Returns `1000`
/// when the variable is unset, is not valid Unicode, or does not parse as a
/// non-negative integer.
pub fn stream_buffer_size_bytes() -> usize {
    // Parse exactly once: a second `.parse()` on the `Result` returned by the
    // first call does not type-check.
    env::var("STREAM_BUFFER_SIZE_BYTES")
        .ok()
        .and_then(|raw| raw.parse().ok())
        .unwrap_or(1000)
}

/// Token limit taken from the environment.
///
/// Reads `DEFAULT_MAX_TOKENS` and parses it as a `u32`. Returns `4096` when the
/// variable is unset, is not valid Unicode, or does not parse.
pub fn default_max_tokens() -> u32 {
    const FALLBACK: u32 = 4096;

    match env::var("DEFAULT_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(FALLBACK),
        Err(_) => FALLBACK,
    }
}
8 changes: 8 additions & 0 deletions src/models/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::collections::HashMap;
use super::content::ChatCompletionMessage;
use super::logprob::LogProbs;
use super::streaming::ChatCompletionChunk;
use super::tool_choice::ToolChoice;
use super::tool_definition::ToolDefinition;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone)]
Expand All @@ -25,12 +27,18 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

Expand Down
7 changes: 6 additions & 1 deletion src/models/content.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use super::tool_calls::ChatMessageToolCall;

#[derive(Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum ChatMessageContent {
Expand All @@ -17,7 +19,10 @@ pub struct ChatMessageContentPart {
#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
pub role: String,
pub content: ChatMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<ChatMessageContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ChatMessageToolCall>>,
}
2 changes: 2 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ pub mod embeddings;
pub mod logprob;
pub mod streaming;
pub mod tool_calls;
pub mod tool_choice;
pub mod tool_definition;
pub mod usage;
4 changes: 2 additions & 2 deletions src/models/streaming.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};

use super::logprob::ChoiceLogprobs;
use super::tool_calls::ChoiceDeltaToolCall;
use super::tool_calls::ChatMessageToolCall;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone, Debug)]
Expand All @@ -11,7 +11,7 @@ pub struct ChoiceDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ChoiceDeltaToolCall>>,
pub tool_calls: Option<Vec<ChatMessageToolCall>>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
Expand Down
29 changes: 8 additions & 21 deletions src/models/tool_calls.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,15 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaFunctionCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
pub struct FunctionCall {
pub arguments: String,
pub name: String,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaToolCallFunction {
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaToolCall {
pub index: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<ChoiceDeltaToolCallFunction>,
#[serde(skip_serializing_if = "Option::is_none")]
pub r#type: Option<String>,
pub struct ChatMessageToolCall {
pub id: String,
pub function: FunctionCall,
#[serde(rename = "type")]
pub r#type: String, // Using `function` as the only valid value
}
34 changes: 34 additions & 0 deletions src/models/tool_choice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use serde::{Deserialize, Serialize};

/// Tool-choice value: either a keyword mode or a named function selection.
///
/// `#[serde(untagged)]` means no variant tag is written; the value
/// (de)serializes directly as the inner type, and deserialization tries the
/// variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
/// Keyword form; see [`SimpleToolChoice`].
Simple(SimpleToolChoice),
/// Structured form with a `type` and a `function`; see
/// [`ChatCompletionNamedToolChoice`].
Named(ChatCompletionNamedToolChoice),
}

/// Keyword-only tool-choice modes.
///
/// `rename_all = "lowercase"` renames the variants to `none`, `auto` and
/// `required` for (de)serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SimpleToolChoice {
None,
Auto,
Required,
}

/// Tool choice that names one specific function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionNamedToolChoice {
/// Kind of tool; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub tool_type: ToolType,
/// The function being selected, identified by name.
pub function: Function,
}

/// Kind of tool referenced by a named tool choice.
///
/// `Function` is the only variant; `rename_all = "lowercase"` renames it to
/// `function` for (de)serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
Function,
}

/// Reference to a function by name only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
/// Name of the function.
pub name: String,
}
23 changes: 23 additions & 0 deletions src/models/tool_definition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// A tool definition: a function definition plus a tool type string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolDefinition {
/// The function this tool exposes.
pub function: FunctionDefinition,

/// Tool kind; (de)serialized under the key `type`.
// NOTE(review): this is an unconstrained `String`, so nothing in this type
// itself rejects values other than "function" — confirm where that is enforced.
#[serde(rename = "type")]
pub tool_type: String, // Will only accept "function" value
}

/// A definition of a function that can be called.
///
/// Only `name` is required; each optional field is left out of the serialized
/// output when it is `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionDefinition {
/// Name of the function.
pub name: String,
/// Optional free-text description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional map from string keys to arbitrary JSON values.
// NOTE(review): presumably a JSON Schema object describing the arguments — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<HashMap<String, serde_json::Value>>,
/// Optional boolean flag.
// NOTE(review): presumably toggles strict schema adherence — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
6 changes: 6 additions & 0 deletions src/models/usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ use serde::{Deserialize, Serialize};

/// Optional per-category token counts for the completion side of a usage record.
///
/// Every field is optional and is left out of the serialized output when it is
/// `None`.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CompletionTokensDetails {
#[serde(skip_serializing_if = "Option::is_none")]
pub accepted_prediction_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub audio_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub rejected_prediction_tokens: Option<u32>,
}

/// Optional per-category token counts for the prompt side of a usage record.
///
/// Both fields are optional and are left out of the serialized output when
/// they are `None`.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct PromptTokensDetails {
#[serde(skip_serializing_if = "Option::is_none")]
pub audio_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_tokens: Option<u32>,
}

Expand Down
66 changes: 37 additions & 29 deletions src/pipelines/otel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl OtelTracer {
completion.choices.get_mut(chunk_choice.index as usize)
{
if let Some(content) = &chunk_choice.delta.content {
if let ChatMessageContent::String(existing_content) =
if let Some(ChatMessageContent::String(existing_content)) =
&mut existing_choice.message.content
{
existing_content.push_str(content);
Expand All @@ -86,6 +86,9 @@ impl OtelTracer {
if chunk_choice.finish_reason.is_some() {
existing_choice.finish_reason = chunk_choice.finish_reason.clone();
}
if let Some(tool_calls) = &chunk_choice.delta.tool_calls {
existing_choice.message.tool_calls = Some(tool_calls.clone());
}
} else {
completion.choices.push(ChatCompletionChoice {
index: chunk_choice.index,
Expand All @@ -96,9 +99,10 @@ impl OtelTracer {
.role
.clone()
.unwrap_or_else(|| "assistant".to_string()),
content: ChatMessageContent::String(
content: Some(ChatMessageContent::String(
chunk_choice.delta.content.clone().unwrap_or_default(),
),
)),
tool_calls: chunk_choice.delta.tool_calls.clone(),
},
finish_reason: chunk_choice.finish_reason.clone(),
logprobs: None,
Expand Down Expand Up @@ -150,19 +154,21 @@ impl RecordSpan for ChatCompletionRequest {
}

for (i, message) in self.messages.iter().enumerate() {
span.set_attribute(KeyValue::new(
format!("gen_ai.prompt.{}.role", i),
message.role.clone(),
));
span.set_attribute(KeyValue::new(
format!("gen_ai.prompt.{}.content", i),
match &message.content {
ChatMessageContent::String(content) => content.clone(),
ChatMessageContent::Array(content) => {
serde_json::to_string(content).unwrap_or_default()
}
},
));
if let Some(content) = &message.content {
span.set_attribute(KeyValue::new(
format!("gen_ai.prompt.{}.role", i),
message.role.clone(),
));
span.set_attribute(KeyValue::new(
format!("gen_ai.prompt.{}.content", i),
match &content {
ChatMessageContent::String(content) => content.clone(),
ChatMessageContent::Array(content) => {
serde_json::to_string(content).unwrap_or_default()
}
},
));
}
}
}
}
Expand All @@ -175,19 +181,21 @@ impl RecordSpan for ChatCompletion {
self.usage.record_span(span);

for choice in &self.choices {
span.set_attribute(KeyValue::new(
format!("gen_ai.completion.{}.role", choice.index),
choice.message.role.clone(),
));
span.set_attribute(KeyValue::new(
format!("gen_ai.completion.{}.content", choice.index),
match &choice.message.content {
ChatMessageContent::String(content) => content.clone(),
ChatMessageContent::Array(content) => {
serde_json::to_string(content).unwrap_or_default()
}
},
));
if let Some(content) = &choice.message.content {
span.set_attribute(KeyValue::new(
format!("gen_ai.completion.{}.role", choice.index),
choice.message.role.clone(),
));
span.set_attribute(KeyValue::new(
format!("gen_ai.completion.{}.content", choice.index),
match &content {
ChatMessageContent::String(content) => content.clone(),
ChatMessageContent::Array(content) => {
serde_json::to_string(content).unwrap_or_default()
}
},
));
}
span.set_attribute(KeyValue::new(
format!("gen_ai.completion.{}.finish_reason", choice.index),
choice.finish_reason.clone().unwrap_or_default(),
Expand Down
Loading