From ec0959db447cadf8cebb8185e5fe68b2683b412d Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 6 Apr 2024 18:33:17 -0700 Subject: [PATCH] [tokenizer] set truncation to default Avoid crash when tokens exceed model max length. --- .../tokenizers/HuggingFaceTokenizer.java | 5 ++-- .../tokenizers/HuggingFaceTokenizerTest.java | 23 ++++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index ba4d61b79b1..80a7a3b5eb1 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -54,9 +54,8 @@ public final class HuggingFaceTokenizer extends NativeResource implements private HuggingFaceTokenizer(long handle, Map options) { super(handle); - String val = TokenizersLibrary.LIB.getTruncationStrategy(handle); - truncation = TruncationStrategy.fromValue(val); - val = TokenizersLibrary.LIB.getPaddingStrategy(handle); + truncation = TruncationStrategy.LONGEST_FIRST; + String val = TokenizersLibrary.LIB.getPaddingStrategy(handle); padding = PaddingStrategy.fromValue(val); maxLength = TokenizersLibrary.LIB.getMaxLength(handle); stride = TokenizersLibrary.LIB.getStride(handle); diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 0c548d51aec..8b91c8590cd 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -40,7 +40,11 @@ public void testTokenizer() throws IOException { "[CLS]", "Hello", ",", "y", "'", "all", 
"!", "How", "are", "you", "[UNK]", "?", "[SEP]" }; - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + try (HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder() + .optTokenizerName("bert-base-cased") + .optTruncation(false) + .build()) { Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE"); Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD"); Assert.assertEquals(tokenizer.getMaxLength(), -1); @@ -212,7 +216,10 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { stringBuilder.append(repeat); } List inputs = Arrays.asList(stringBuilder.toString(), "This is a short sentence"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + Map options = new ConcurrentHashMap<>(); + options.put("tokenizer", "bert-base-cased"); + options.put("truncation", "false"); + try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { int[] expectedNumberOfIdsNoTruncationNoPadding = new int[] {numRepeats * 2 + 2, 7}; Encoding[] encodings = tokenizer.batchEncode(inputs); for (int i = 0; i < encodings.length; ++i) { @@ -221,10 +228,7 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { } } - Map options = new ConcurrentHashMap<>(); - options.put("tokenizer", "bert-base-cased"); - options.put("truncation", "true"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { + try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { int[] expectedSize = new int[] {512, 7}; Encoding[] encodings = tokenizer.batchEncode(inputs); for (int i = 0; i < encodings.length; ++i) { @@ -232,8 +236,11 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { } } - options.put("padding", "true"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { + try (HuggingFaceTokenizer tokenizer 
= + HuggingFaceTokenizer.builder() + .optTokenizerName("bert-base-cased") + .optPadding(true) + .build()) { Encoding[] encodings = tokenizer.batchEncode(inputs); for (Encoding encoding : encodings) { Assert.assertEquals(encoding.getIds().length, 512);