Skip to content

Commit 61c4651

Browse files
author
Judd
committed
support reversed role
1 parent 8681325 commit 61c4651

File tree

7 files changed

+54
-11
lines changed

7 files changed

+54
-11
lines changed

models/cohere.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class ChatHistoryEncoder : public BaseHistoryEncoder
1414
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
1515
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
1616
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
17+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
1718
};
1819

1920
static ChatHistoryEncoder _chat_encoder;
@@ -166,6 +167,14 @@ void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids)
166167
ids.push_back(tok->chatbot_token_id);
167168
}
168169

170+
// Opens a user turn by appending two token ids to `ids`: the tokenizer's
// start_of_turn_token_id followed by its user_token_id.
// `round_idx` is not used by this implementation.
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// presumably `tokenizer` is always this model's Tokenizer -- confirm.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
171+
{
172+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
173+
174+
ids.push_back(tok->start_of_turn_token_id);
175+
ids.push_back(tok->user_token_id);
176+
}
177+
169178
}
170179

171180
namespace aya_23

models/gemma.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class ChatHistoryEncoder : public BaseHistoryEncoder
1414
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
1515
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
1616
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
17+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
1718
};
1819

1920
static ChatHistoryEncoder _chat_encoder;
@@ -153,10 +154,13 @@ void ChatHistoryEncoder::append_user(int round_idx, const std::string &user, std
153154
// Opens a model (AI) turn by encoding the literal "model\n" into `ids`.
// In this diff the `156-` line is the pre-commit std::ostringstream local being
// removed, and the `157+` line is the direct encode call that replaces it.
// `round_idx` is not used by this implementation.
// NOTE(review): the meaning of the trailing (true, false) flags is defined by
// Tokenizer::encode, which is not visible here -- confirm before changing them.
void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
154155
{
155156
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
156-
std::ostringstream oss_prompt;
157+
tok->encode("model\n", ids, true, false);
158+
}
157159

158-
oss_prompt << "model" << "\n";
159-
tok->encode(oss_prompt.str(), ids, true, false);
160+
// Opens a user turn by encoding the literal "user\n" into `ids`, using the same
// (true, false) flag pair that append_ai_opening passes for "model\n".
// `round_idx` is not used by this implementation.
// NOTE(review): the flag semantics live in Tokenizer::encode, which is not
// visible here; the dynamic_cast result is also used without a null check.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
161+
{
162+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
163+
tok->encode("user\n", ids, true, false);
160164
}
161165
}
162166

models/granite.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace moe
2121
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
2222
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
2323
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
24+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
2425
};
2526

2627
static ChatHistoryEncoder _chat_encoder;
@@ -110,11 +111,15 @@ namespace moe
110111
// Opens an assistant turn by asking the tokenizer to encode an "assistant"
// header into `ids`.
// In this diff the `113-`/`114-` lines remove a std::ostringstream local (and
// the line after it) from the body; the encode_header call is unchanged context.
// `round_idx` is not used by this implementation.
void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
111112
{
112113
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
113-
std::ostringstream oss;
114-
115114
tok->encode_header("assistant", ids);
116115
}
117116

117+
// Opens a user turn by asking the tokenizer to encode a "user" header into
// `ids`; the call has the same shape as the "assistant" header call in
// append_ai_opening.
// `round_idx` is not used by this implementation.
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// presumably `tokenizer` is always this model's Tokenizer -- confirm.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
118+
{
119+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
120+
tok->encode_header("user", ids);
121+
}
122+
118123
template <int NUM_EXPERTS, int EXPERTS_PER_TOK> class GraniteSparseMoE : public BaseSparseMLP
119124
{
120125
public:

models/llama.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ namespace v3
224224
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
225225
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
226226
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
227+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
227228
};
228229

229230
static ChatHistoryEncoder _chat_encoder;
@@ -322,11 +323,15 @@ namespace v3
322323
// Opens an assistant turn by asking the tokenizer to encode an "assistant"
// header into `ids`.
// In this diff the `325-`/`326-` lines remove a std::ostringstream local (and
// the line after it) from the body; the encode_header call is unchanged context.
// `round_idx` is not used by this implementation.
void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
323324
{
324325
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
325-
std::ostringstream oss;
326-
327326
tok->encode_header("assistant", ids);
328327
}
329328

329+
// Opens a user turn by asking the tokenizer to encode a "user" header into
// `ids`; the call has the same shape as the "assistant" header call in
// append_ai_opening.
// `round_idx` is not used by this implementation.
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// presumably `tokenizer` is always this model's Tokenizer -- confirm.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
330+
{
331+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
332+
tok->encode_header("user", ids);
333+
}
334+
330335
class ConditionalGeneration : public v2::GenericConditionalGeneration<LlamaBlock>
331336
{
332337
public:

models/minicpm.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ namespace v2
160160
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
161161
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
162162
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
163+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
163164
};
164165

165166
static ChatHistoryEncoder _chat_encoder;
@@ -249,11 +250,13 @@ namespace v2
249250
// Opens an assistant turn by encoding the literal "assistant\n" into `ids`.
// The `253+`/`254+` lines are the post-commit body: a direct encode call on the
// string literal.
// `round_idx` is not used by this implementation.
// NOTE(review): the meaning of the trailing (true, false) flags is defined by
// Tokenizer::encode, which is not visible here -- confirm before changing them.
void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids) const
250251
{
251252
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
253+
tok->encode("assistant\n", ids, true, false);
254+
}
252255

253-
std::ostringstream oss_prompt;
254-
255-
oss_prompt << "assistant\n";
256-
tok->encode(oss_prompt.str(), ids, true, false);
256+
// Opens a user turn by encoding the literal "user\n" into `ids`, using the same
// (true, false) flag pair that append_ai_opening passes for "assistant\n".
// `round_idx` is not used by this implementation.
// NOTE(review): the flag semantics live in Tokenizer::encode, which is not
// visible here; the dynamic_cast result is also used without a null check.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
257+
{
258+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
259+
tok->encode("user\n", ids, true, false);
257260
}
258261
}
259262

models/smollm.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class ChatHistoryEncoder : public BaseHistoryEncoder
77
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
88
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
99
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
10+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
1011
};
1112

1213
static ChatHistoryEncoder _chat_encoder;
@@ -71,6 +72,13 @@ void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids)
7172
tok->encode("assistant\n", ids);
7273
}
7374

75+
// Opens a user turn: pushes the tokenizer's bos_token_id onto `ids`, then
// encodes the literal "user\n" after it.
// `round_idx` is not used by this implementation.
// NOTE(review): presumably bos_token_id doubles as the turn-start marker for
// this model -- confirm against the tokenizer setup, which is not visible here.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
76+
{
77+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
78+
ids.push_back(tok->bos_token_id);
79+
tok->encode("user\n", ids);
80+
}
81+
7482
class ConditionalGeneration : public llama::v2::GenericConditionalGeneration<LlamaBlock>
7583
{
7684
public:

models/yi.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class ChatHistoryEncoder : public BaseHistoryEncoder
1212
void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override;
1313
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
1414
void append_ai_opening(int round_idx, std::vector<int> &ids) const override;
15+
void append_user_opening(int round_idx, std::vector<int> &ids) const override;
1516
};
1617

1718
static ChatHistoryEncoder _chat_encoder;
@@ -102,6 +103,14 @@ void ChatHistoryEncoder::append_ai_opening(int round_idx, std::vector<int> &ids)
102103
tok->encode("assistant\n", ids);
103104
}
104105

106+
// Opens a user turn: pushes the tokenizer's im_start_token_id onto `ids`, then
// encodes the literal "user\n" after it.
// `round_idx` is not used by this implementation.
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// presumably `tokenizer` is always this model's Tokenizer -- confirm.
void ChatHistoryEncoder::append_user_opening(int round_idx, std::vector<int> &ids) const
107+
{
108+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
109+
110+
ids.push_back(tok->im_start_token_id);
111+
tok->encode("user\n", ids);
112+
}
113+
105114
bool Tokenizer::is_special_id(int id) const
106115
{
107116
return llama::v2::Tokenizer::is_special_id(id)

0 commit comments

Comments
 (0)