diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java index 86354477fcbb..440670afc098 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java @@ -25,6 +25,11 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; +/** + * This is the Utility for pre-processing the data for Bert Model + * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary, + * clean and tokenize sentences and pad the text + */ public class BertDataParser { private Map token2idx; @@ -32,7 +37,7 @@ public class BertDataParser { /** * Parse the Vocabulary to JSON files - * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved token + * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens * @param jsonFile the filePath of the vocab.json * @throws Exception */ @@ -52,13 +57,13 @@ void parseJSON(String jsonFile) throws Exception { } /** - * Tokenize the input, split all kinds of spaces and - * saparate the end of sentence symbol: . , ? ! - * @param input The input String + * Tokenize the input, split all kinds of whitespace and + * Separate the end of sentence symbol: . , ? ! + * @param input The input string * @return List of tokens */ List tokenizer(String input) { - String[] step1 = input.split("[\n\r\t ]+"); + String[] step1 = input.split("\\s+"); List finalResult = new LinkedList<>(); for (String item : step1) { if (item.length() != 0) { diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index 3254faeb08d1..b40a4e94afbd 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -26,6 +26,11 @@ import java.util.*; +/** + * This is an example of using BERT to do the general Question and Answer inference jobs + * Users can provide a question with a paragraph contains answer to the model and + * the model will be able to find the best answer from the answer paragraph + */ public class BertQA { @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") private String modelPathPrefix = "/model/static_bert_qa"; diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md index 4e65406e8541..7925a259f48f 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -17,16 +17,18 @@ # Run BERT QA model using Java Inference API -In this tutorial, we will walk through the BERT QA model trained by MXNet. -You will be able to run inference with general Q & A task: +In this tutorial, we will walk through the BERT QA model trained by MXNet. +Users can provide a question with a paragraph contains answer to the model and +the model will be able to find the best answer from the answer paragraph. +Example: ```text Q: When did BBC Japan start broadcasting? ``` -The model are expected to find the right answer in the corresponding text: +Answer paragraph ```text -BBC Japan was a general entertainment Channel. Which operated between December 2004 and April 2006. +BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006. It ceased operations after its Japanese distributor folded. ``` And it picked up the right one: @@ -74,18 +76,18 @@ From the `scala-package/examples/scripts/infer/bert/` folder run: ## Background -To learn more about how BERT works in MXNet, please follow this [tutorial](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). +To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). -The model was extracted from the GluonNLP with static length settings. +The model was extracted from MXNet GluonNLP with static length settings. [Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip) -The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). +The original description can be found in the [MXNet GluonNLP model zoo](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). ```bash python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export ``` -This script would generate a `json` and `param` which are the standard MXNet model files. +This script will generate `json` and `param` fles that are the standard MXNet model files. By default, this model are using `bert_12_768_12` model with extra layers for QA jobs. After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text @@ -97,5 +99,5 @@ f = open("vocab.json", "w") f.write(json_str) f.close() ``` -This would export a json file for you to deal with the vocabulary. +This would export the token vocabulary in json format. Once you have these three files, you will be able to run this example without problems. diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java index cb46ea9cc0fd..0518254c297d 100644 --- a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java @@ -26,6 +26,9 @@ import java.io.File; +/** + * Test on BERT QA model + */ public class BertExampleTest { final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class); private static String modelPathPrefix = "";