@@ -3884,6 +3884,10 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
3884
3884
return vocab.id_to_token [id].type == LLAMA_TOKEN_TYPE_BYTE;
3885
3885
}
3886
3886
3887
+ static bool llama_is_user_defined_token (const llama_vocab& vocab, llama_token id) {
3888
+ return vocab.id_to_token [id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
3889
+ }
3890
+
3887
3891
static uint8_t llama_token_to_byte (const llama_vocab& vocab, llama_token id) {
3888
3892
GGML_ASSERT (llama_is_byte_token (vocab, id));
3889
3893
const auto & token_data = vocab.id_to_token .at (id);
@@ -7224,47 +7228,53 @@ static std::string llama_decode_text(const std::string& text) {
7224
7228
// does not write null-terminator to buf
7225
7229
int llama_token_to_piece_with_model (const struct llama_model * model, llama_token token, char * buf, int length) {
7226
7230
if (0 <= token && token < llama_model_n_vocab (model)) {
7227
- if (llama_is_normal_token (model->vocab , token)) {
7228
- std::string result = model->vocab .id_to_token [token].text ;
7229
- if (llama_vocab_get_type (model->vocab ) == LLAMA_VOCAB_TYPE_SPM) {
7231
+ switch (llama_vocab_get_type (model->vocab )) {
7232
+ case LLAMA_VOCAB_TYPE_SPM: {
7233
+ if (llama_is_normal_token (model->vocab , token)) {
7234
+ std::string result = model->vocab .id_to_token [token].text ;
7230
7235
llama_unescape_whitespace (result);
7231
- } else if (llama_vocab_get_type (model->vocab ) == LLAMA_VOCAB_TYPE_BPE) {
7232
- result = llama_decode_text (result);
7233
- } else {
7234
- GGML_ASSERT (false );
7235
- }
7236
- if (length < (int ) result.length ()) {
7237
- return -result.length ();
7238
- }
7239
- memcpy (buf, result.c_str (), result.length ());
7240
- return result.length ();
7241
- } else if (llama_is_unknown_token (model->vocab , token)) { // NOLINT
7242
- if (length < 3 ) {
7243
- return -3 ;
7244
- }
7245
- buf[0 ] = ' \xe2 ' ;
7246
- buf[1 ] = ' \x96 ' ;
7247
- buf[2 ] = ' \x85 ' ;
7248
- return 3 ;
7249
- } else if (llama_is_control_token (model->vocab , token)) {
7250
- ;
7251
- } else if (llama_is_byte_token (model->vocab , token)) {
7252
- if (llama_vocab_get_type (model->vocab ) == LLAMA_VOCAB_TYPE_SPM) {
7236
+ if (length < (int ) result.length ()) {
7237
+ return -result.length ();
7238
+ }
7239
+ memcpy (buf, result.c_str (), result.length ());
7240
+ return result.length ();
7241
+ } else if (llama_is_unknown_token (model->vocab , token)) { // NOLINT
7242
+ if (length < 3 ) {
7243
+ return -3 ;
7244
+ }
7245
+ memcpy (buf, " \xe2\x96\x85 " , 3 );
7246
+ return 3 ;
7247
+ } else if (llama_is_control_token (model->vocab , token)) {
7248
+ ;
7249
+ } else if (llama_is_byte_token (model->vocab , token)) {
7253
7250
if (length < 1 ) {
7254
7251
return -1 ;
7255
7252
}
7256
7253
buf[0 ] = llama_token_to_byte (model->vocab , token);
7257
7254
return 1 ;
7258
- } else if (llama_vocab_get_type (model->vocab ) == LLAMA_VOCAB_TYPE_BPE) {
7259
- std::string result = llama_decode_text (model->vocab .id_to_token [token].text );
7260
- if (length < (int )result.length ()) {
7255
+ } else {
7256
+ GGML_ASSERT (false );
7257
+ }
7258
+ break ;
7259
+ }
7260
+ case LLAMA_VOCAB_TYPE_BPE: {
7261
+ if (llama_is_normal_token (model->vocab , token)) {
7262
+ std::string result = model->vocab .id_to_token [token].text ;
7263
+ result = llama_decode_text (result);
7264
+ if (length < (int ) result.length ()) {
7261
7265
return -result.length ();
7262
7266
}
7263
7267
memcpy (buf, result.c_str (), result.length ());
7264
7268
return result.length ();
7269
+ } else if (llama_is_control_token (model->vocab , token)) {
7270
+ ;
7265
7271
} else {
7266
7272
GGML_ASSERT (false );
7267
7273
}
7274
+ break ;
7275
+ }
7276
+ default :
7277
+ GGML_ASSERT (false );
7268
7278
}
7269
7279
}
7270
7280
return 0 ;
0 commit comments