From 9b77bf9b130e2f0b63fd5fbc8a5ceb975cd9905e Mon Sep 17 00:00:00 2001
From: Anand Inguva <34158215+AnandInguva@users.noreply.github.com>
Date: Sat, 11 Feb 2023 11:49:20 -0500
Subject: [PATCH] Add support for loading torchscript models (#25321)

* Add support for loading torchscript models

* Add tests

* Add example and benchmark

* Add validate_constructor_args method

* Addressing comments, fixing types

* Fix/add tests

* Add change log

* revert changes

* Add few more checks and refactor

* Fixup lint

* Revert "Fixup lint"

This reverts commit b18e10f5b6add921c2527dcec0b2c64ea385b082.

* Add ignore for mypy

* Make validate_constructor_args local to pytorch handler
---
 CHANGES.md                                    |   2 +
 .../ml/inference/pytorch_inference.py         | 140 ++++++++++++++----
 .../ml/inference/pytorch_inference_test.py    | 131 ++++++++++++++++
 3 files changed, 243 insertions(+), 30 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 691b2bb2d561..fcc1451848cb 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,6 +66,8 @@
 * Add UDF metrics support for Samza portable mode.
 * Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
   This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
+* Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be
+  passed to PytorchModelHandler using `torch_script_model_path=<path_to_model>`. ([#25321](https://github.com/apache/beam/pull/25321))
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 520f133be5d0..71a4ccc63a27 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -57,32 +57,66 @@
                         Iterable[PredictionResult]]
 
 
-def _load_model(
-    model_class: torch.nn.Module, state_dict_path, device, **model_params):
-  model = model_class(**model_params)
+def _validate_constructor_args(
+    state_dict_path, model_class, torch_script_model_path):
+  message = (
+      "A {param1} has been supplied to the model "
+      "handler, but the required {param2} is missing. "
+      "Please provide the {param2} in order to "
+      "successfully load the {param1}.")
+  # state_dict_path and model_class are coupled with each other;
+  # raise a RuntimeError if the user passes only one of them.
+  if state_dict_path and not model_class:
+    raise RuntimeError(
+        message.format(param1="state_dict_path", param2="model_class"))
+
+  if not state_dict_path and model_class:
+    raise RuntimeError(
+        message.format(param1="model_class", param2="state_dict_path"))
+
+  if torch_script_model_path and state_dict_path:
+    raise RuntimeError(
+        "Please specify either torch_script_model_path or "
+        "(state_dict_path, model_class) to successfully load the model.")
+
+
+def _load_model(
+    model_class: Optional[Callable[..., torch.nn.Module]],
+    state_dict_path: Optional[str],
+    device: torch.device,
+    model_params: Optional[Dict[str, Any]],
+    torch_script_model_path: Optional[str]):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
-        "Model handler specified a 'GPU' device, but GPUs are not available. " \
+        "Model handler specified a 'GPU' device, but GPUs are not available. "
" "Switching to CPU.") device = torch.device('cpu') - file = FileSystems.open(state_dict_path, 'rb') try: logging.info( "Loading state_dict_path %s onto a %s device", state_dict_path, device) - state_dict = torch.load(file, map_location=device) + if not torch_script_model_path: + file = FileSystems.open(state_dict_path, 'rb') + model = model_class(**model_params) # type: ignore[misc] + state_dict = torch.load(file, map_location=device) + model.load_state_dict(state_dict) + else: + file = FileSystems.open(torch_script_model_path, 'rb') + model = torch.jit.load(file, map_location=device) except RuntimeError as e: if device == torch.device('cuda'): message = "Loading the model onto a GPU device failed due to an " \ f"exception:\n{e}\nAttempting to load onto a CPU device instead." logging.warning(message) return _load_model( - model_class, state_dict_path, torch.device('cpu'), **model_params) + model_class, + state_dict_path, + torch.device('cpu'), + model_params, + torch_script_model_path) else: raise e - model.load_state_dict(state_dict) model.to(device) model.eval() logging.info("Finished loading PyTorch model.") @@ -148,19 +182,23 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, torch.nn.Module]): def __init__( self, - state_dict_path: str, - model_class: Callable[..., torch.nn.Module], - model_params: Dict[str, Any], + state_dict_path: Optional[str] = None, + model_class: Optional[Callable[..., torch.nn.Module]] = None, + model_params: Optional[Dict[str, Any]] = None, device: str = 'CPU', *, inference_fn: TensorInferenceFn = default_tensor_inference_fn, + torch_script_model_path: Optional[str] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None): """Implementation of the ModelHandler interface for PyTorch. - Example Usage:: - - pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri")) + Example Usage for torch model:: + pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri", + model_class="my_class")) + Example Usage for torchscript model:: + pcoll | RunInference(PytorchModelHandlerTensor( + torch_script_model_path="my_uri")) See https://pytorch.org/tutorials/beginner/saving_loading_models.html for details @@ -176,6 +214,10 @@ def __init__( Otherwise, it will be CPU. inference_fn: the inference function to use during RunInference. default=_default_tensor_inference_fn + torch_script_model_path: Path to the torch script model. + the model will be loaded using `torch.jit.load()`. + `state_dict_path`, `model_class` and `model_params` + arguments will be disregarded. min_batch_size: the minimum batch size to use when batching inputs. This batch will be fed into the inference_fn as a Sequence of Tensors. max_batch_size: the maximum batch size to use when batching inputs. 
@@ -192,13 +234,19 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
     self._batching_kwargs = {}
     if min_batch_size is not None:
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
+
+    _validate_constructor_args(
+        state_dict_path=self._state_dict_path,
+        model_class=self._model_class,
+        torch_script_model_path=self._torch_script_model_path)
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -206,12 +254,19 @@ def load_model(self) -> torch.nn.Module:
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._torch_script_model_path
+    )
     self._device = device
     return model
 
   def update_model_path(self, model_path: Optional[str] = None):
-    self._state_dict_path = model_path if model_path else self._state_dict_path
+    if self._torch_script_model_path:
+      self._torch_script_model_path = (
+          model_path if model_path else self._torch_script_model_path)
+    else:
+      self._state_dict_path = (
+          model_path if model_path else self._state_dict_path)
 
   def run_inference(
       self,
@@ -240,9 +295,11 @@ def run_inference(
       An Iterable of type PredictionResult.
     """
     inference_args = {} if not inference_args else inference_args
-
+    model_id = (
+        self._state_dict_path
+        if not self._torch_script_model_path else self._torch_script_model_path)
     return self._inference_fn(
-        batch, model, self._device, inference_args, self._state_dict_path)
+        batch, model, self._device, inference_args, model_id)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
@@ -336,20 +393,25 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str,
                                                        torch.Tensor],
                                                   torch.nn.Module]):
   def __init__(
      self,
-      state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      state_dict_path: Optional[str] = None,
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
       inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
+      torch_script_model_path: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None):
     """Implementation of the ModelHandler interface for PyTorch.
 
-    Example Usage::
+    Example Usage for torch model::
+
+      pcoll | RunInference(PytorchModelHandlerKeyedTensor(
+        state_dict_path="my_uri",
+        model_class="my_class"))
 
-      pcoll | RunInference(
-          PytorchModelHandlerKeyedTensor(state_dict_path="my_uri"))
+    Example Usage for torchscript model::
+
+      pcoll | RunInference(PytorchModelHandlerKeyedTensor(
+        torch_script_model_path="my_uri"))
 
     **NOTE:** This API and its implementation are under development and
     do not provide backward compatibility guarantees.
@@ -368,6 +430,10 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the function to invoke on run_inference.
         default = default_keyed_tensor_inference_fn
+      torch_script_model_path: Path to the TorchScript model.
+        The model will be loaded using `torch.jit.load()`; the
+        `state_dict_path`, `model_class` and `model_params`
+        arguments will be disregarded.
       min_batch_size: the minimum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
       max_batch_size: the maximum batch size to use when batching inputs. This
@@ -385,13 +451,18 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
     self._batching_kwargs = {}
     if min_batch_size is not None:
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
+    _validate_constructor_args(
+        state_dict_path=self._state_dict_path,
+        model_class=self._model_class,
+        torch_script_model_path=self._torch_script_model_path)
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -399,12 +470,19 @@ def load_model(self) -> torch.nn.Module:
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._torch_script_model_path
+    )
     self._device = device
     return model
 
   def update_model_path(self, model_path: Optional[str] = None):
-    self._state_dict_path = model_path if model_path else self._state_dict_path
+    if self._torch_script_model_path:
+      self._torch_script_model_path = (
+          model_path if model_path else self._torch_script_model_path)
+    else:
+      self._state_dict_path = (
+          model_path if model_path else self._state_dict_path)
 
   def run_inference(
       self,
@@ -433,9 +511,11 @@ def run_inference(
       An Iterable of type PredictionResult.
     """
     inference_args = {} if not inference_args else inference_args
-
+    model_id = (
+        self._state_dict_path
+        if not self._torch_script_model_path else self._torch_script_model_path)
     return self._inference_fn(
-        batch, model, self._device, inference_args, self._state_dict_path)
+        batch, model, self._device, inference_args, model_id)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 9d4276d1d42e..947e06e20473 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -123,6 +123,7 @@ def __init__(self, device, *, inference_fn=default_tensor_inference_fn):
     self._device = device
     self._inference_fn = inference_fn
     self._state_dict_path = None
+    self._torch_script_model_path = None
 
 
 class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
@@ -131,6 +132,7 @@ def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn):
     self._device = device
     self._inference_fn = inference_fn
     self._state_dict_path = None
+    self._torch_script_model_path = None
 
 
 def _compare_prediction_result(x, y):
@@ -701,6 +703,135 @@ def test_gpu_auto_convert_to_cpu(self):
           "are not available. "
Switching to CPU.", log.output) + def test_load_torch_script_model(self): + torch_model = PytorchLinearRegression(2, 1) + torch_script_model = torch.jit.script(torch_model) + + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + + torch.jit.save(torch_script_model, torch_script_path) + + model_handler = PytorchModelHandlerTensor( + torch_script_model_path=torch_script_path) + + torch_script_model = model_handler.load_model() + + self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule)) + + def test_inference_torch_script_model(self): + torch_model = PytorchLinearRegression(2, 1) + torch_model.load_state_dict( + OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])), + ('linear.bias', torch.Tensor([0.5]))])) + + torch_script_model = torch.jit.script(torch_model) + + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + + torch.jit.save(torch_script_model, torch_script_path) + + model_handler = PytorchModelHandlerTensor( + torch_script_model_path=torch_script_path) + + with TestPipeline() as pipeline: + pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + predictions = pcoll | RunInference(model_handler) + assert_that( + predictions, + equal_to( + TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result)) + + def test_torch_model_class_none(self): + torch_model = PytorchLinearRegression(2, 1) + torch_path = os.path.join(self.tmpdir, 'torch_model.pt') + + torch.save(torch_model, torch_path) + + with self.assertRaisesRegex( + RuntimeError, + "A state_dict_path has been supplied to the model " + "handler, but the required model_class is missing. " + "Please provide the model_class in order to"): + _ = PytorchModelHandlerTensor(state_dict_path=torch_path) + + with self.assertRaisesRegex( + RuntimeError, + "A state_dict_path has been supplied to the model " + "handler, but the required model_class is missing. " + "Please provide the model_class in order to"): + _ = (PytorchModelHandlerKeyedTensor(state_dict_path=torch_path)) + + def test_torch_model_state_dict_none(self): + with self.assertRaisesRegex( + RuntimeError, + "A model_class has been supplied to the model " + "handler, but the required state_dict_path is missing. " + "Please provide the state_dict_path in order to"): + _ = PytorchModelHandlerTensor(model_class=PytorchLinearRegression) + + with self.assertRaisesRegex( + RuntimeError, + "A model_class has been supplied to the model " + "handler, but the required state_dict_path is missing. 
" + "Please provide the state_dict_path in order to"): + _ = PytorchModelHandlerKeyedTensor(model_class=PytorchLinearRegression) + + def test_specify_torch_script_path_and_state_dict_path(self): + torch_model = PytorchLinearRegression(2, 1) + torch_path = os.path.join(self.tmpdir, 'torch_model.pt') + + torch.save(torch_model, torch_path) + torch_script_model = torch.jit.script(torch_model) + + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + + torch.jit.save(torch_script_model, torch_script_path) + with self.assertRaisesRegex( + RuntimeError, "Please specify either torch_script_model_path or "): + _ = PytorchModelHandlerTensor( + state_dict_path=torch_path, + model_class=PytorchLinearRegression, + torch_script_model_path=torch_script_path) + + def test_prediction_result_model_id_with_torch_script_model(self): + torch_model = PytorchLinearRegression(2, 1) + torch_script_model = torch.jit.script(torch_model) + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + torch.jit.save(torch_script_model, torch_script_path) + + model_handler = PytorchModelHandlerTensor( + torch_script_model_path=torch_script_path) + + def check_torch_script_model_id(element): + assert ('torch_script_model.pt' in element.model_id) is True + + with TestPipeline() as pipeline: + pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + predictions = pcoll | RunInference(model_handler) + _ = predictions | beam.Map(check_torch_script_model_id) + + def test_prediction_result_model_id_with_torch_model(self): + # weights associated with PytorchLinearRegression class + state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])), + ('linear.bias', torch.Tensor([0.5]))]) + torch_path = os.path.join(self.tmpdir, 'torch_model.pt') + torch.save(state_dict, torch_path) + + model_handler = PytorchModelHandlerTensor( + state_dict_path=torch_path, + model_class=PytorchLinearRegression, + model_params={ + 'input_dim': 2, 'output_dim': 1 + }) + + def check_torch_script_model_id(element): + assert ('torch_model.pt' in element.model_id) is True + + with TestPipeline() as pipeline: + pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + predictions = pcoll | RunInference(model_handler) + _ = predictions | beam.Map(check_torch_script_model_id) + if __name__ == '__main__': unittest.main()