
Commit b62e633

feat: support separate chat_template.jinja file (#1853)

1 parent 8ae3719 commit b62e633

File tree

3 files changed: +47 −2 lines changed

lib/llm/src/model_card/create.rs
lib/llm/src/model_card/model.rs
lib/llm/src/preprocessor/prompt/template.rs

lib/llm/src/model_card/create.rs

Lines changed: 15 additions & 0 deletions

```diff
@@ -86,6 +86,7 @@ impl ModelDeploymentCard
             tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
             gen_config: None, // AFAICT there is no equivalent in a GGUF
             prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
+            chat_template_file: None,
             prompt_context: None, // TODO - auto-detect prompt context
             revision: 0,
             last_published: None,
@@ -124,6 +125,7 @@ impl ModelDeploymentCard
             tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
             gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
             prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
+            chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
             prompt_context: None, // TODO - auto-detect prompt context
             revision: 0,
             last_published: None,
@@ -157,6 +159,19 @@ impl PromptFormatterArtifact
             .ok())
     }

+    pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
+        Ok(Self::chat_template_try_is_hf_repo(repo_id)
+            .await
+            .with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
+            .ok())
+    }
+
+    async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
+        Ok(Self::HfChatTemplate(
+            check_for_file(repo, "chat_template.jinja").await?,
+        ))
+    }
+
     async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
         Ok(Self::HfTokenizerConfigJson(
             check_for_file(repo, "tokenizer_config.json").await?,
```
lib/llm/src/model_card/model.rs

Lines changed: 15 additions & 0 deletions

```diff
@@ -62,6 +62,7 @@ pub enum TokenizerKind
 #[serde(rename_all = "snake_case")]
 pub enum PromptFormatterArtifact {
     HfTokenizerConfigJson(String),
+    HfChatTemplate(String),
     GGUF(PathBuf),
 }

@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub prompt_formatter: Option<PromptFormatterArtifact>,

+    /// chat template may be stored as a separate file instead of in `prompt_formatter`.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub chat_template_file: Option<PromptFormatterArtifact>,
+
     /// Generation config - default sampling params
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub gen_config: Option<GenerationConfig>,
@@ -259,6 +264,11 @@ impl ModelDeploymentCard
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
+        nats_upload!(
+            self.chat_template_file,
+            PromptFormatterArtifact::HfChatTemplate,
+            "chat_template.jinja"
+        );
        nats_upload!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
@@ -308,6 +318,11 @@ impl ModelDeploymentCard
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
+        nats_download!(
+            self.chat_template_file,
+            PromptFormatterArtifact::HfChatTemplate,
+            "chat_template.jinja"
+        );
        nats_download!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
```
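Since `PromptFormatterArtifact` is an externally tagged serde enum with `rename_all = "snake_case"`, the new variant appears on the wire as `hf_chat_template`, and `#[serde(default)]` means cards published before this change still deserialize with `chat_template_file` as `None`. A round-trip sketch with trimmed-down mirrors of the types (the `GGUF` variant and the card's other fields are omitted, and the path is invented):

```rust
use serde::{Deserialize, Serialize};

// Trimmed-down mirror of the enum above, just to show the wire format.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum PromptFormatterArtifact {
    HfTokenizerConfigJson(String),
    HfChatTemplate(String),
}

#[derive(Debug, Serialize, Deserialize)]
struct Card {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    chat_template_file: Option<PromptFormatterArtifact>,
}

fn main() -> serde_json::Result<()> {
    let card = Card {
        chat_template_file: Some(PromptFormatterArtifact::HfChatTemplate(
            "/cache/llama4/chat_template.jinja".into(),
        )),
    };
    // {"chat_template_file":{"hf_chat_template":"/cache/llama4/chat_template.jinja"}}
    println!("{}", serde_json::to_string(&card)?);

    // An older card without the field still deserializes: #[serde(default)]
    // yields chat_template_file = None.
    let old: Card = serde_json::from_str("{}")?;
    println!("{:?}", old.chat_template_file);
    Ok(())
}
```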

lib/llm/src/preprocessor/prompt/template.rs

Lines changed: 17 additions & 2 deletions

```diff
@@ -26,7 +26,7 @@ mod oai;
 mod tokcfg;

 use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
-use tokcfg::ChatTemplate;
+use tokcfg::{ChatTemplate, ChatTemplateValue};

 impl PromptFormatter {
     pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
@@ -37,13 +37,28 @@ impl PromptFormatter
             PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                 let content = std::fs::read_to_string(&file)
                     .with_context(|| format!("fs:read_to_string '{file}'"))?;
-                let config: ChatTemplate = serde_json::from_str(&content)?;
+                let mut config: ChatTemplate = serde_json::from_str(&content)?;
+                // Some HF models (e.g. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
+                // store the chat template in a separate file. We check whether that file
+                // exists and, if so, put the chat template into config as normalization.
+                if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
+                    mdc.chat_template_file
+                {
+                    let chat_template = std::fs::read_to_string(&chat_template_file)
+                        .with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
+                    // clean up the string to remove newlines
+                    let chat_template = chat_template.replace('\n', "");
+                    config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
+                }
                 Self::from_parts(
                     config,
                     mdc.prompt_context
                         .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
                 )
             }
+            PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
+                "prompt_formatter should not have type HfChatTemplate"
+            )),
             PromptFormatterArtifact::GGUF(gguf_path) => {
                 let config = ChatTemplate::from_gguf(&gguf_path)?;
                 Self::from_parts(config, ContextMixins::default())
```
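The net effect of the normalization is precedence: when a standalone `chat_template.jinja` exists, its contents replace whatever `chat_template` value `tokenizer_config.json` carried. A simplified sketch of that override in isolation; `ChatTemplate` below is a stand-in with only the one field the diff touches, since `tokcfg`'s real definition is not part of this change:

```rust
use std::fs;
use std::path::Path;

// Stand-in for tokcfg::ChatTemplate; only the field the diff touches.
#[derive(Default)]
struct ChatTemplate {
    chat_template: Option<String>,
}

// Mirrors the normalization above: if the repo ships a standalone
// chat_template.jinja, its contents win over whatever template the
// tokenizer_config.json carried.
fn apply_chat_template_override(
    config: &mut ChatTemplate,
    repo_dir: &Path,
) -> std::io::Result<()> {
    let path = repo_dir.join("chat_template.jinja");
    if path.exists() {
        let raw = fs::read_to_string(&path)?;
        // As in the diff's cleanup step: strip newlines so the template
        // is stored as a single line.
        config.chat_template = Some(raw.replace('\n', ""));
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    let mut config = ChatTemplate::default();
    apply_chat_template_override(&mut config, Path::new("."))?;
    println!("override applied: {}", config.chat_template.is_some());
    Ok(())
}
```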
