Skip to content

Commit

Permalink
[tokenizer] set truncation to default
Browse files Browse the repository at this point in the history
Avoid a crash when the number of tokens exceeds the model's max length.
  • Loading branch information
frankfliu committed Apr 7, 2024
1 parent f8791db commit ec0959d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ public final class HuggingFaceTokenizer extends NativeResource<Long> implements

private HuggingFaceTokenizer(long handle, Map<String, String> options) {
super(handle);
String val = TokenizersLibrary.LIB.getTruncationStrategy(handle);
truncation = TruncationStrategy.fromValue(val);
val = TokenizersLibrary.LIB.getPaddingStrategy(handle);
truncation = TruncationStrategy.LONGEST_FIRST;
String val = TokenizersLibrary.LIB.getPaddingStrategy(handle);
padding = PaddingStrategy.fromValue(val);
maxLength = TokenizersLibrary.LIB.getMaxLength(handle);
stride = TokenizersLibrary.LIB.getStride(handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ public void testTokenizer() throws IOException {
"[CLS]", "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?", "[SEP]"
};

try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optTruncation(false)
.build()) {
Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE");
Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD");
Assert.assertEquals(tokenizer.getMaxLength(), -1);
Expand Down Expand Up @@ -212,7 +216,10 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException {
stringBuilder.append(repeat);
}
List<String> inputs = Arrays.asList(stringBuilder.toString(), "This is a short sentence");
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
Map<String, String> options = new ConcurrentHashMap<>();
options.put("tokenizer", "bert-base-cased");
options.put("truncation", "false");
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) {
int[] expectedNumberOfIdsNoTruncationNoPadding = new int[] {numRepeats * 2 + 2, 7};
Encoding[] encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Expand All @@ -221,19 +228,19 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException {
}
}

Map<String, String> options = new ConcurrentHashMap<>();
options.put("tokenizer", "bert-base-cased");
options.put("truncation", "true");
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) {
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
int[] expectedSize = new int[] {512, 7};
Encoding[] encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Assert.assertEquals(encodings[i].getIds().length, expectedSize[i]);
}
}

options.put("padding", "true");
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) {
try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optPadding(true)
.build()) {
Encoding[] encodings = tokenizer.batchEncode(inputs);
for (Encoding encoding : encodings) {
Assert.assertEquals(encoding.getIds().length, 512);
Expand Down

0 comments on commit ec0959d

Please sign in to comment.