Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Apr 2, 2019
1 parent 8fd1cab commit f8b89c0
Showing 1 changed file with 68 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.javaapi.infer.predictor;

import org.apache.mxnetexamples.Util;
import org.apache.mxnetexamples.javaapi.infer.bert.BertQA;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

public class BertExampleTest {
final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class);
private static String modelPathPrefix = "";
private static String vocabPath = "";

@BeforeClass
public static void downloadFile() {
logger.info("Downloading Bert QA Model");
String tempDirPath = System.getProperty("java.io.tmpdir");
logger.info("tempDirPath: %s".format(tempDirPath));

String baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA";
Util.downloadUrl(baseUrl + "/static_bert_qa-symbol.json",
tempDirPath + "/static_bert_qa/static_bert_qa-symbol.json", 3);
Util.downloadUrl(baseUrl + "/static_bert_qa-0002.params",
tempDirPath + "/static_bert_qa/static_bert_qa-0002.params", 3);
Util.downloadUrl(baseUrl + "/vocab.json",
tempDirPath + "/static_bert_qa/vocab.json", 3);
modelPathPrefix = tempDirPath + File.separator + "static_bert_qa/static_bert_qa";
vocabPath = tempDirPath + File.separator + "static_bert_qa/vocab.json";
}

@Test
public void testBertQA() throws Exception{
BertQA bert = new BertQA();
String Q = "When did BBC Japan start broadcasting?";
String A = "BBC Japan was a general entertainment Channel.\n" +
" Which operated between December 2004 and April 2006.\n" +
"It ceased operations after its Japanese distributor folded.";
String[] args = new String[] {
"--model-path-prefix", modelPathPrefix,
"--model-vocab", vocabPath,
"--model-epoch", "2",
"--input-question", Q,
"--input-answer", A,
"--seq-length", "384"
};
bert.main(args);
}
}

0 comments on commit f8b89c0

Please sign in to comment.