Skip to content

Commit

Permalink
Add support for loading torchscript models (#25321)
Browse files Browse the repository at this point in the history
* Add support for loading torchscript models

* Add tests

* Add example and benchmark

* Add validate_constructor_args method

* Addressing comments, fixing types

* Fix/add tests

* Add change log

* revert changes

* Add few more checks and refactor

* Fixup lint

* Revert "Fixup lint"

This reverts commit b18e10f.

* Add ignore for mypy

* Make validate_constructor_args local to pytorch handler
  • Loading branch information
AnandInguva authored Feb 11, 2023
1 parent 783584d commit 9b77bf9
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 30 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
* Add UDF metrics support for Samza portable mode.
* Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
* Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be
passed to PytorchModelHandler using `torch_script_model_path=<path_to_model>`. ([#25321](https://github.com/apache/beam/pull/25321))

## Breaking Changes

Expand Down
140 changes: 110 additions & 30 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,66 @@
Iterable[PredictionResult]]


def _load_model(
model_class: torch.nn.Module, state_dict_path, device, **model_params):
model = model_class(**model_params)
def _validate_constructor_args(
state_dict_path, model_class, torch_script_model_path):
message = (
"A {param1} has been supplied to the model "
"handler, but the required {param2} is missing. "
"Please provide the {param2} in order to "
"successfully load the {param1}.")
# state_dict_path and model_class are coupled with each other
# raise RuntimeError if user forgets to pass any one of them.
if state_dict_path and not model_class:
raise RuntimeError(
message.format(param1="state_dict_path", param2="model_class"))

if not state_dict_path and model_class:
raise RuntimeError(
message.format(param1="model_class", param2="state_dict_path"))

if torch_script_model_path and state_dict_path:
raise RuntimeError(
"Please specify either torch_script_model_path or "
"(state_dict_path, model_class) to successfully load the model.")


def _load_model(
model_class: Optional[Callable[..., torch.nn.Module]],
state_dict_path: Optional[str],
device: torch.device,
model_params: Optional[Dict[str, Any]],
torch_script_model_path: Optional[str]):
if device == torch.device('cuda') and not torch.cuda.is_available():
logging.warning(
"Model handler specified a 'GPU' device, but GPUs are not available. " \
"Model handler specified a 'GPU' device, but GPUs are not available. "
"Switching to CPU.")
device = torch.device('cpu')

file = FileSystems.open(state_dict_path, 'rb')
try:
logging.info(
"Loading state_dict_path %s onto a %s device", state_dict_path, device)
state_dict = torch.load(file, map_location=device)
if not torch_script_model_path:
file = FileSystems.open(state_dict_path, 'rb')
model = model_class(**model_params) # type: ignore[misc]
state_dict = torch.load(file, map_location=device)
model.load_state_dict(state_dict)
else:
file = FileSystems.open(torch_script_model_path, 'rb')
model = torch.jit.load(file, map_location=device)
except RuntimeError as e:
if device == torch.device('cuda'):
message = "Loading the model onto a GPU device failed due to an " \
f"exception:\n{e}\nAttempting to load onto a CPU device instead."
logging.warning(message)
return _load_model(
model_class, state_dict_path, torch.device('cpu'), **model_params)
model_class,
state_dict_path,
torch.device('cpu'),
model_params,
torch_script_model_path)
else:
raise e

model.load_state_dict(state_dict)
model.to(device)
model.eval()
logging.info("Finished loading PyTorch model.")
Expand Down Expand Up @@ -148,19 +182,23 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
torch.nn.Module]):
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
state_dict_path: Optional[str] = None,
model_class: Optional[Callable[..., torch.nn.Module]] = None,
model_params: Optional[Dict[str, Any]] = None,
device: str = 'CPU',
*,
inference_fn: TensorInferenceFn = default_tensor_inference_fn,
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None):
"""Implementation of the ModelHandler interface for PyTorch.
Example Usage::
pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri"))
Example Usage for torch model::
pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri",
model_class="my_class"))
Example Usage for torchscript model::
pcoll | RunInference(PytorchModelHandlerTensor(
torch_script_model_path="my_uri"))
See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
Expand All @@ -176,6 +214,10 @@ def __init__(
Otherwise, it will be CPU.
inference_fn: the inference function to use during RunInference.
default=_default_tensor_inference_fn
torch_script_model_path: Path to the torch script model.
the model will be loaded using `torch.jit.load()`.
`state_dict_path`, `model_class` and `model_params`
arguments will be disregarded.
min_batch_size: the minimum batch size to use when batching inputs. This
batch will be fed into the inference_fn as a Sequence of Tensors.
max_batch_size: the maximum batch size to use when batching inputs. This
Expand All @@ -192,26 +234,39 @@ def __init__(
logging.info("Device is set to CPU")
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._model_params = model_params if model_params else {}
self._inference_fn = inference_fn
self._batching_kwargs = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path

_validate_constructor_args(
state_dict_path=self._state_dict_path,
model_class=self._model_class,
torch_script_model_path=self._torch_script_model_path)

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)
self._model_params,
self._torch_script_model_path
)
self._device = device
return model

def update_model_path(self, model_path: Optional[str] = None):
self._state_dict_path = model_path if model_path else self._state_dict_path
if self._torch_script_model_path:
self._torch_script_model_path = (
model_path if model_path else self._torch_script_model_path)
else:
self._state_dict_path = (
model_path if model_path else self._state_dict_path)

def run_inference(
self,
Expand Down Expand Up @@ -240,9 +295,11 @@ def run_inference(
An Iterable of type PredictionResult.
"""
inference_args = {} if not inference_args else inference_args

model_id = (
self._state_dict_path
if not self._torch_script_model_path else self._torch_script_model_path)
return self._inference_fn(
batch, model, self._device, inference_args, self._state_dict_path)
batch, model, self._device, inference_args, model_id)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
Expand Down Expand Up @@ -336,20 +393,25 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
torch.nn.Module]):
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
state_dict_path: Optional[str] = None,
model_class: Optional[Callable[..., torch.nn.Module]] = None,
model_params: Optional[Dict[str, Any]] = None,
device: str = 'CPU',
*,
inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None):
"""Implementation of the ModelHandler interface for PyTorch.
Example Usage::
Example Usage for torch model::
pcoll | RunInference(PytorchModelHandlerKeyedTensor(
state_dict_path="my_uri",
model_class="my_class"))
pcoll | RunInference(
PytorchModelHandlerKeyedTensor(state_dict_path="my_uri"))
Example Usage for torchscript model::
pcoll | RunInference(PytorchModelHandlerKeyedTensor(
torch_script_model_path="my_uri"))
**NOTE:** This API and its implementation are under development and
do not provide backward compatibility guarantees.
Expand All @@ -368,6 +430,10 @@ def __init__(
Otherwise, it will be CPU.
inference_fn: the function to invoke on run_inference.
default = default_keyed_tensor_inference_fn
torch_script_model_path: Path to the torch script model.
the model will be loaded using `torch.jit.load()`.
`state_dict_path`, `model_class` and `model_params`
arguments will be disregarded..
min_batch_size: the minimum batch size to use when batching inputs. This
batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
max_batch_size: the maximum batch size to use when batching inputs. This
Expand All @@ -385,26 +451,38 @@ def __init__(
logging.info("Device is set to CPU")
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._model_params = model_params if model_params else {}
self._inference_fn = inference_fn
self._batching_kwargs = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path
_validate_constructor_args(
state_dict_path=self._state_dict_path,
model_class=self._model_class,
torch_script_model_path=self._torch_script_model_path)

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)
self._model_params,
self._torch_script_model_path
)
self._device = device
return model

def update_model_path(self, model_path: Optional[str] = None):
self._state_dict_path = model_path if model_path else self._state_dict_path
if self._torch_script_model_path:
self._torch_script_model_path = (
model_path if model_path else self._torch_script_model_path)
else:
self._state_dict_path = (
model_path if model_path else self._state_dict_path)

def run_inference(
self,
Expand Down Expand Up @@ -433,9 +511,11 @@ def run_inference(
An Iterable of type PredictionResult.
"""
inference_args = {} if not inference_args else inference_args

model_id = (
self._state_dict_path
if not self._torch_script_model_path else self._torch_script_model_path)
return self._inference_fn(
batch, model, self._device, inference_args, self._state_dict_path)
batch, model, self._device, inference_args, model_id)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
Expand Down
Loading

0 comments on commit 9b77bf9

Please sign in to comment.