Skip to content

Commit

Permalink
fix: remove accidentally included guideline from rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Nov 22, 2024
1 parent 4069955 commit 8770b39
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
17 changes: 3 additions & 14 deletions router/src/infer/chat_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::infer::InferError;
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use std::collections::HashSet;

/// Raise a exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Expand All @@ -15,7 +14,6 @@ pub(crate) struct ChatTemplate {
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
variables: HashSet<String>,
}

impl ChatTemplate {
Expand Down Expand Up @@ -47,21 +45,14 @@ impl ChatTemplate {
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
variables,
}
}

pub(crate) fn apply(
&self,
guideline: Option<&str>,
mut messages: Vec<Message>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
// check if guideline is expected but not provided
if self.variables.contains("guideline") && guideline.is_none() {
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
}

let tools = match tools_and_prompt {
Some((tools, tool_prompt)) => {
// check if the `tools` variable is used in the template
Expand All @@ -88,7 +79,6 @@ impl ChatTemplate {
let mut rendered_template = self
.template
.render(ChatTemplateInputs {
guideline,
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
Expand Down Expand Up @@ -782,7 +772,6 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
guideline: Some("Do not use offensive language."),
..Default::default()
},
target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
Expand Down Expand Up @@ -843,7 +832,7 @@ mod tests {
},
];

let result = ct.apply(None, msgs, None);
let result = ct.apply(msgs, None);

match result {
Ok(_) => panic!("Should have failed since no guideline is provided"),
Expand Down Expand Up @@ -885,7 +874,7 @@ mod tests {
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
Expand Down Expand Up @@ -919,7 +908,7 @@ mod tests {
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
Expand Down
3 changes: 1 addition & 2 deletions router/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,13 @@ impl Infer {
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
guideline: Option<String>,
messages: Vec<Message>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(guideline.as_deref(), messages, tools_and_prompt)
.apply(messages, tools_and_prompt)
.map_err(|e| {
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
Expand Down

0 comments on commit 8770b39

Please sign in to comment.