diff --git a/test/integration/local/test_serving.py b/test/integration/local/test_serving.py index 85e12116..8ed5201a 100644 --- a/test/integration/local/test_serving.py +++ b/test/integration/local/test_serving.py @@ -89,10 +89,10 @@ def test_serving_calls_model_fn_once(image_uri, sagemaker_local_session, instanc @contextmanager def _predictor(model_tar, script, image, sagemaker_local_session, instance_type, model_server_workers=None): - model = PyTorchModel('file://{}'.format(model_tar), - ROLE, - script, - image=image, + model = PyTorchModel(model_data='file://{}'.format(model_tar), + role=ROLE, + entry_point=script, + image_uri=image, sagemaker_session=sagemaker_local_session, model_server_workers=model_server_workers)