Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
philschmid committed May 8, 2024
1 parent e0b66cc commit 00d1345
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 42 deletions.
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ check_dirs := src tests
# run tests

unit-test:
python -m pytest -n auto --dist loadfile -s -v ./tests/unit/
python -m pytest -v -s ./tests/unit/

integ-test:
python -m pytest -n 2 -s -v ./tests/integ/
Expand Down
15 changes: 6 additions & 9 deletions tests/unit/test_handler_service_without_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def test_handle(inference_handler):
inference_handler.initialize(CONTEXT)
json_data = json.dumps(INPUT)
prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
loaded_response = json.loads(prediction[0])
assert "entity" in loaded_response[0]
assert "score" in loaded_response[0]
assert "output" in prediction[0]


@require_torch
Expand All @@ -90,13 +88,15 @@ def test_load(inference_handler):
model_dir=tmpdirname,
)
# test with automatic infer
if "HF_TASK" in os.environ:
del os.environ["HF_TASK"]
hf_pipeline_without_task = inference_handler.load(storage_folder)
assert hf_pipeline_without_task.task == "token-classification"

# test with task set explicitly via the HF_TASK environment variable
os.environ["HF_TASK"] = TASK
os.environ["HF_TASK"] = "text-classification"
hf_pipeline_with_task = inference_handler.load(storage_folder)
assert hf_pipeline_with_task.task == TASK
assert hf_pipeline_with_task.task == "text-classification"


def test_preprocess(inference_handler):
Expand Down Expand Up @@ -139,10 +139,7 @@ def test_validate_and_initialize_user_module(inference_handler):
prediction = inference_handler.handle([{"body": b""}], CONTEXT)
assert "output" in prediction[0]

assert inference_handler.load({}) == "model"
assert inference_handler.preprocess({}, "") == "data"
assert inference_handler.predict({}, "model") == "output"
assert inference_handler.postprocess("output", "") == "output"
assert inference_handler.load({}) == "Loading inference_tranform_fn.py"


def test_validate_and_initialize_user_module_transform_fn():
Expand Down
32 changes: 0 additions & 32 deletions tests/unit/test_mms_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,35 +186,3 @@ def test_start_mms_with_model_from_hub(
subprocess_popen.assert_called_once_with(multi_model_server_cmd)
sigterm.assert_called_once_with(retrieve.return_value)
os.remove(mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY)


# Decorators are applied bottom-up, so the positional test parameters below are
# in reverse decorator order: `adapt` receives the `_adapt_to_mms_format` mock,
# `_aws_neuron_available` receives the topmost patch, etc.
@patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)
@patch("subprocess.call")
@patch("subprocess.Popen")
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
@patch("os.makedirs", return_value=True)
@patch("os.remove", return_value=True)
@patch("os.path.exists", return_value=True)
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
def test_start_mms_neuron_and_model_from_hub(
    adapt,
    create_config,
    exists,
    remove,
    dir,
    install_requirements,
    sigterm,
    load_model_from_hub,
    retrieve,
    subprocess_popen,
    subprocess_call,
    _aws_neuron_available,
):
    """start_model_server must raise ValueError when Neuron is reported available and HF_MODEL_ID is set."""
    # _aws_neuron_available is mocked to return True, and a hub model is selected
    # through the HF_MODEL_ID environment variable; the server start is expected
    # to reject this combination with a ValueError.
    # NOTE(review): the env var is set inside the pytest.raises block, so it only
    # takes effect if start_model_server reads it at call time — confirm, and
    # note the test never cleans HF_MODEL_ID up afterwards (leaks into later tests).
    with pytest.raises(ValueError):
        os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random"

        mms_model_server.start_model_server()

0 comments on commit 00d1345

Please sign in to comment.