diff --git a/tests/proj/main/test_export_model.py b/tests/proj/main/test_export_model.py index b80930b57..82b3832f2 100644 --- a/tests/proj/main/test_export_model.py +++ b/tests/proj/main/test_export_model.py @@ -17,6 +17,23 @@ [ ("bert", BertPreTrainedModel, "bert-base-cased"), ("roberta", RobertaForMaskedLM, "nyu-mll/roberta-med-small-1M-1",), + ], +) +def test_export_model(tmp_path, model_type, model_class, hf_pretrained_model_name_or_path): + export_model( + hf_pretrained_model_name_or_path=hf_pretrained_model_name_or_path, + output_base_path=tmp_path, + ) + read_config = py_io.read_json(os.path.join(tmp_path, f"config.json")) + assert read_config["model_type"] == model_type + assert read_config["model_path"] == os.path.join(tmp_path, "model", f"{model_type}.p") + assert read_config["model_config_path"] == os.path.join(tmp_path, "model", f"{model_type}.json") + + +@pytest.mark.slow +@pytest.mark.parametrize( + "model_type, model_class, hf_pretrained_model_name_or_path", + [ ("deberta-v2", DebertaV2ForMaskedLM, "microsoft/deberta-v2-xlarge",), ], )