From 6efe660fa2ca1bc2498b181dc00656376d7a8083 Mon Sep 17 00:00:00 2001 From: Ewan <166796318+ewan0x79@users.noreply.github.com> Date: Thu, 25 Apr 2024 08:57:05 +0800 Subject: [PATCH] [tokenizer] add optional tokenizerPath Prior to modelPath (#3120) * [tokenizer] add optional tokenizerPath Prior to modelPath --------- Co-authored-by: Frank Liu --- .../djl/huggingface/tokenizers/HuggingFaceTokenizer.java | 8 +++++--- .../tokenizers/CrossEncoderTranslatorTest.java | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 0eaed72a709..8d32ccb4b6c 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -30,6 +30,7 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Arrays; import java.util.List; import java.util.Locale; @@ -686,7 +687,6 @@ static PaddingStrategy fromValue(String value) { /** The builder for creating huggingface tokenizer. */ public static final class Builder { - private Path tokenizerPath; private NDManager manager; private Map options; @@ -724,7 +724,7 @@ public Builder optTokenizerName(String tokenizerName) { * @return this builder */ public Builder optTokenizerPath(Path tokenizerPath) { - this.tokenizerPath = tokenizerPath; + options.putIfAbsent("tokenizerPath", tokenizerPath.toString()); return this; } @@ -894,9 +894,11 @@ public HuggingFaceTokenizer build() throws IOException { if (tokenizerName != null) { return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options)); } - if (tokenizerPath == null) { + String path = options.get("tokenizerPath"); + if (path == null) { throw new IllegalArgumentException("Missing tokenizer path."); } + Path tokenizerPath = Paths.get(path); if (Files.isDirectory(tokenizerPath)) { Path tokenizerFile = tokenizerPath.resolve("tokenizer.json"); if (Files.exists(tokenizerFile)) { diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java index ef4015d94d3..2a98f63db65 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -64,6 +64,7 @@ public void testCrossEncoderTranslator() .optBlock(block) .optEngine("PyTorch") .optArgument("tokenizer", "bert-base-cased") + .optArgument("tokenizerPath", modelDir) .optOption("hasParameter", "false") .optTranslatorFactory(new CrossEncoderTranslatorFactory()) .build();