Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Apr 3, 2019
1 parent 6e25592 commit d053102
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet.javaapi

/**
* Layout definition of DataDesc
* N Batch size
* C channels
* H Height
* W Weight
* T sequence length
* __undefined__ default value of Layout
*/
object Layout {
val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED
val NCHW: String = org.apache.mxnet.Layout.NCHW
val NTC: String = org.apache.mxnet.Layout.NTC
val NT: String = org.apache.mxnet.Layout.NT
val N: String = org.apache.mxnet.Layout.N
}
3 changes: 0 additions & 3 deletions scala-package/examples/scripts/infer/bert/get_bert_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

if [ ! -f "$data_path" ]; then
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

public class BertUtil {
public class BertDataParser {

private Map<String, Integer> token2idx;
private List<String> idx2token;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ static List<String> postProcessing(NDArray result, List<String> tokens) {
NDArray[] output = NDArray.split(
NDArray.new splitParam(result, 2).setAxis(2));
// Get the formatted logits result
NDArray startLogits = NDArray.reshape(
NDArray.new reshapeParam(output[0]).setShape(new Shape(new int[]{0, -3})))[0];
NDArray endLogits = NDArray.reshape(
NDArray.new reshapeParam(output[1]).setShape(new Shape(new int[]{0, -3})))[0];
NDArray startLogits = output[0].reshape(new int[]{0, -3});
NDArray endLogits = output[1].reshape(new int[]{0, -3});
// Get Probability distribution
float[] startProb = NDArray.softmax(
NDArray.new softmaxParam(startLogits))[0].toArray();
Expand All @@ -83,7 +81,7 @@ public static void main(String[] args) throws Exception{
BertQA inst = new BertQA();
CmdLineParser parser = new CmdLineParser(inst);
parser.parseArgument(args);
BertUtil util = new BertUtil();
BertDataParser util = new BertDataParser();
Context context = Context.cpu();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
Expand Down Expand Up @@ -115,26 +113,25 @@ public static void main(String[] args) throws Exception{
indexesFloat.add((float) integer);
}
// Preparing the input data
NDArray inputs = new NDArray(indexesFloat,
new Shape(new int[]{1, inst.seqLength}), context);
NDArray tokenTypesND = new NDArray(tokenTypes,
new Shape(new int[]{1, inst.seqLength}), context);
NDArray validLengthND = new NDArray(new float[] {(float) validLength},
new Shape(new int[]{1}), context);
List<NDArray> inputBatch = new ArrayList<>();
inputBatch.add(inputs);
inputBatch.add(tokenTypesND);
inputBatch.add(validLengthND);
List<NDArray> inputBatch = Arrays.asList(
new NDArray(indexesFloat,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(tokenTypes,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(new float[] { validLength },
new Shape(new int[]{1}), context)
);
// Build the model
List<DataDesc> inputDescs = new ArrayList<>();
List<Context> contexts = new ArrayList<>();
contexts.add(context);
inputDescs.add(new DataDesc("data0",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT"));
inputDescs.add(new DataDesc("data1",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT"));
inputDescs.add(new DataDesc("data2",
new Shape(new int[]{1}), DType.Float32(), "N"));
List<DataDesc> inputDescs = Arrays.asList(
new DataDesc("data0",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data1",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data2",
new Shape(new int[]{1}), DType.Float32(), Layout.N())
);
Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
// Start prediction
NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ From the `scala-package/examples/scripts/infer/bert/` folder run:
./get_bert_data.sh
```

**Note**: You may need to run `chmod +x get_bert_data.sh` before running this script.

### Step 2: Setup data path of the model

### Setup Datapath and Parameters
Expand Down Expand Up @@ -80,7 +78,7 @@ To learn more about how BERT works in MXNet, please follow this [tutorial](https

The model was extracted from the GluonNLP with static length settings.

[Download link for the scrtipt](https://gluon-nlp.mxnet.io/_downloads/bert.zip)
[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip)

The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1).
```bash
Expand Down

0 comments on commit d053102

Please sign in to comment.