Skip to content

Commit 7ef5129

Browse files
authored
Fix truncation issue in classify_review function (rasbt#373)
1 parent b56d0b2 commit 7ef5129

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

ch06/01_main-chapter-code/ch06.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2207,7 +2207,9 @@
22072207
"\n",
22082208
" # Prepare inputs to the model\n",
22092209
" input_ids = tokenizer.encode(text)\n",
2210-
" supported_context_length = model.pos_emb.weight.shape[1]\n",
2210+
" supported_context_length = model.pos_emb.weight.shape[0]\n",
2211+
" # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake\n",
2212+
" # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)\n",
22112213
"\n",
22122214
" # Truncate sequences if they too long\n",
22132215
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",

ch06/01_main-chapter-code/load-finetuned-model.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
"\n",
180180
" # Prepare inputs to the model\n",
181181
" input_ids = tokenizer.encode(text)\n",
182-
" supported_context_length = model.pos_emb.weight.shape[1]\n",
182+
" supported_context_length = model.pos_emb.weight.shape[0]\n",
183183
"\n",
184184
" # Truncate sequences if they too long\n",
185185
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",

ch06/04_user_interface/previous_chapters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def classify_review(text, model, tokenizer, device, max_length=None, pad_token_i
353353

354354
# Prepare inputs to the model
355355
input_ids = tokenizer.encode(text)
356-
supported_context_length = model.pos_emb.weight.shape[1]
356+
supported_context_length = model.pos_emb.weight.shape[0]
357357

358358
# Truncate sequences if they are too long
359359
input_ids = input_ids[:min(max_length, supported_context_length)]

0 commit comments

Comments
 (0)