Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new token classification example #8340

Merged
merged 13 commits into from
Nov 9, 2020
41 changes: 40 additions & 1 deletion examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@

# Directories containing the example scripts under test; each is added to
# sys.path so the scripts can be imported as top-level modules below.
# NOTE(review): this span was a garbled diff fragment that kept both the old
# single-line list and its multi-line replacement; merged to the post-change
# version (which adds "token-classification").
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "question-answering",
    ]
]
sys.path.extend(SRC_DIRS)

Expand All @@ -38,6 +44,7 @@
import run_generation
import run_glue
import run_mlm
import run_ner_new as run_ner
import run_pl_glue
import run_squad

Expand Down Expand Up @@ -185,6 +192,38 @@ def test_run_mlm(self):
result = run_mlm.main()
self.assertLess(result["perplexity"], 42)

def test_run_ner(self):
    """End-to-end smoke test for the token-classification (NER) example.

    Runs ``run_ner.main()`` on a tiny CoNLL fixture for two epochs and
    checks the reported eval metrics against loose sanity thresholds.
    """
    # Mirror log output to stdout so CI failures show the training trace.
    logger.addHandler(logging.StreamHandler(sys.stdout))

    output_dir = self.get_auto_remove_tmp_dir()
    fixture = "tests/fixtures/tests_samples/conll/sample.json"
    cli_args = [
        "run_ner.py",
        "--model_name_or_path",
        "bert-base-uncased",
        "--train_file",
        fixture,
        "--validation_file",
        fixture,
        "--output_dir",
        output_dir,
        "--overwrite_output_dir",
        "--do_train",
        "--do_eval",
        "--warmup_steps=2",
        "--learning_rate=2e-4",
        "--per_gpu_train_batch_size=2",
        "--per_gpu_eval_batch_size=2",
        "--num_train_epochs=2",
    ]
    # Force CPU mode when no CUDA device is available to the test run.
    if torch_device != "cuda":
        cli_args.append("--no_cuda")

    with patch.object(sys, "argv", cli_args):
        metrics = run_ner.main()
        # Thresholds are deliberately loose: the fixture is tiny, so this
        # only sanity-checks that training/eval produce plausible numbers.
        self.assertGreaterEqual(metrics["eval_accuracy_score"], 0.75)
        self.assertGreaterEqual(metrics["eval_precision"], 0.75)
        self.assertGreaterEqual(metrics["eval_recall"], 0.2)
        self.assertGreaterEqual(metrics["eval_f1"], 0.25)
        self.assertLess(metrics["eval_loss"], 0.5)

def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
Expand Down
Loading