Support conversion_dtype in ONNXConversion pass and `dynamic_shapes` on torch.onnx.export(..., dynamo=True) (#1511)

## Describe your changes

(1) Introduce `dynamic_shapes`, a crucial parameter to
`torch.onnx.export(..., dynamo=True)`.
(2) Enable a `dynamic` boolean control in the pass config.
(3) Actually apply `torch_dtype` to the model inputs.

NOTE:
1. Automatic generation of `dynamic_shapes` is intentionally not supported
because of the complexity of generating it; torch.onnx.export(dynamo=True)
already supports (limited) conversion of `dynamic_axes` to `dynamic_shapes`.
2. To follow JSON rules, `dynamic_shapes` requires users to provide a
list of [dim_name(str), min(int), max(int)]. This information is later
used to compose `torch.export.Dim(dim_name, min=min, max=max)`; a minimal
sketch follows this list
([details](https://pytorch.org/docs/stable/export.html#expressing-dynamism)).
3. `dynamic_shapes` follows the tree structure of the model inputs. For
example, if the model input is a nested tuple, then `dynamic_shapes`
should be a nested tuple as well, not a dictionary.
4. The `kv_cache` support in `dynamic_shapes` is limited with respect to
variations in model signatures, implementations, and inputs. Users are
encouraged to provide the full KV cache.
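
Roughly, an io_config entry such as `{"input_ids": {"0": ["batch_size", 1, 32]}}` is composed into `torch.export.Dim` objects and handed to `torch.onnx.export(..., dynamo=True)`. The sketch below is illustrative only: the tiny model, names, and shapes are made up, and it assumes a torch build whose `torch.onnx.export` accepts `dynamic_shapes` (with onnxscript installed), as this pass requires.

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2


# JSON-friendly io_config style: {input_name: {axis: [dim_name, min, max]}}
# (the pass also converts string axis keys such as "0" to ints)
dynamic_shapes_config = {"input_ids": {0: ["batch_size", 1, 32], 1: ["seq_length", 1, 256]}}

# Compose torch.export.Dim objects, reusing one Dim per unique axis name
dims = {}
dynamic_shapes = {
    name: {
        axis: dims.setdefault(dim_name, torch.export.Dim(dim_name, min=lo, max=hi))
        for axis, (dim_name, lo, hi) in axes.items()
    }
    for name, axes in dynamic_shapes_config.items()
}

onnx_program = torch.onnx.export(
    TinyModel(),
    (torch.ones(2, 8, dtype=torch.int64),),
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
)
# onnx_program.save("tiny_model.onnx")  # optionally serialize the exported model
```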

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
titaiwangms authored Dec 18, 2024
1 parent 86f742d commit f372b89
Showing 11 changed files with 302 additions and 23 deletions.
3 changes: 3 additions & 0 deletions docs/source/extending/advanced-users.md
@@ -19,6 +19,8 @@ Now, let's take a look at how you can use the advanced Python interface.
Start by creating an instance of an OliveModelHandler to represent the model to be optimized. Depending on the model framework, the
model can be loaded from a file or using a model loader function. For a complete list of available models and their initialization options, refer to [OliveModels api reference](models).

Note: The `dynamic_shapes` field requires more than just a string. It should be a list in the format [str, int, int], representing [dim_name, min_value, max_value]. This will later be converted to torch.export.Dim(dim_name, min=min_value, max=max_value). For more details, refer to the documentation: https://pytorch.org/docs/stable/export.html#expressing-dynamism
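For example, `["batch_size", 0, 16]` becomes `torch.export.Dim("batch_size", min=0, max=16)`.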

```python
from olive.model import ModelConfig

@@ -30,6 +32,7 @@ config = {
"input_shapes": [[1, 3, 32, 32]],
"output_names": ["output"],
"dynamic_axes": {"input": {0: "batch_size"}, "output": {0: "batch_size"}},
"dynamic_shapes": {"input": {0: ["batch_size", 0, 16]}}
}
}
input_model = ModelConfig.parse_obj(config)
@@ -73,6 +73,11 @@ You can also provide your own IO config which will override the automatically fetched
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
},
"dynamic_shapes": {
"input_ids": { "0": ["batch_size", 0, 8], "1": ["sequence_length", 0, 2048] },
"attention_mask": { "0": ["batch_size", 0, 8], "1": ["total_sequence_length", 0, 3072] },
"position_ids": { "0": ["batch_size", 0, 8], "1": ["sequence_length", 0, 2048] }
}
}
}
3 changes: 3 additions & 0 deletions olive/common/hf/model_io.py
@@ -124,6 +124,9 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", **
for axis, axis_name in value.items():
if axis_name == "past_sequence_length + 1":
value[axis] = "past_sequence_length + sequence_length"
# NOTE: Due to the complexity of dynamic_shapes, we don't provide it here.
# The torch-onnx exporter has a naive approach to auto-generate dynamic shapes from the inputs and
# dynamic_axes, so we don't need to provide them here.
return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}


21 changes: 21 additions & 0 deletions olive/common/utils.py
@@ -253,6 +253,27 @@ def tensor_data_to_device(data, device: str):
return data


def tensor_data_to_dtype(data, dtype):
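"""Recursively cast floating-point tensors in (possibly nested) dict/list/tuple/set containers to `dtype`.

Non-floating-point tensors (e.g. integer input ids) and non-tensor values are returned unchanged;
`dtype=None` is a no-op.
"""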
import torch

if dtype is None:
return data

from torch import Tensor

if isinstance(data, Tensor) and data.dtype in {torch.bfloat16, torch.float16, torch.float32, torch.float64}:
return data.to(dtype)
if isinstance(data, dict):
return {k: tensor_data_to_dtype(v, dtype) for k, v in data.items()}
if isinstance(data, list):
return [tensor_data_to_dtype(v, dtype) for v in data]
if isinstance(data, tuple):
return tuple(tensor_data_to_dtype(v, dtype) for v in data)
if isinstance(data, set):
return {tensor_data_to_dtype(v, dtype) for v in data}
return data


def resolve_torch_dtype(dtype):
"""Get torch dtype from string or torch dtype.
12 changes: 12 additions & 0 deletions olive/model/config/io_config.py
@@ -23,6 +23,12 @@ class IoConfig(ConfigBase):
"clip_input": { "0": "batch", "1": "channels", "2": "height", "3": "width" },
"images": { "0": "batch", "1": "height", "2": "width", "3": "channels" }
},
"dynamic_shapes": {
"clip_input": { "0": ["batch", 1, 512], "1": ["channels", 0, 3],
"2": ["height", 0, 512], "3": ["width", 0, 512] },
"images": { "0": ["batch", 1, 512], "1": ["height", 0, 512],
"2": ["width", 0, 512], "3": ["channels", 0, 3] }
},
"kv_cache": None
}
"""
@@ -35,6 +41,12 @@
output_shapes: List[List[int]] = None
output_types: List[str] = None
dynamic_axes: Dict[str, Dict[int, str]] = None
# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
# NOTE: JSON does not support torch.export.Dim, so we use List[str, int, int] here.
# for example, {"input_ids": {0: torch.export.Dim("batch", min=2, max=1024)}}
# -> {"input_ids": {0: ["batch", 2, 1024]}}
dynamic_shapes: Union[List[Any], Dict[str, Any]] = None
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference
# even though we want the dimension to be a constant int.
# We use a workaround here: first use dim_param like "1" to represent the dimension, and then
167 changes: 149 additions & 18 deletions olive/passes/onnx/conversion.py
@@ -7,14 +7,14 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import onnx
import torch
from packaging import version

from olive.common.config_utils import validate_config
from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device
from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device, tensor_data_to_dtype
from olive.hardware import AcceleratorSpec
from olive.model import (
DistributedHfModelHandler,
@@ -103,6 +103,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
"Includes config.json, generation_config.json, and tokenizer related files."
),
),
"dynamic": PassConfigParam(
type_=bool, default_value=True, description=("Whether to export the model with dynamic axes/shapes.")
),
}

def _run_for_config(
@@ -170,6 +173,7 @@ def _export_pytorch_model(
:param device: the device to use for conversion
:param torch_dtype: the dtype to cast the model to before conversion
:param tempdir: directory to use for temporary files
:param dynamic: whether to export the model with dynamic axes/shapes
"""
from olive.common.hf.peft import make_export_compatible_peft
from olive.common.hf.quant import make_export_compatible_quant
@@ -181,6 +185,8 @@

logger.debug("Converting model on device %s with dtype %s.", device, torch_dtype)
pytorch_model.to(device)

dummy_inputs = tensor_data_to_dtype(dummy_inputs, torch_dtype)
dummy_inputs = tensor_data_to_device(dummy_inputs, device)

if isinstance(pytorch_model, torch.jit.RecursiveScriptModule):
@@ -194,6 +200,10 @@
# get input and output names, and dynamic axes
assert io_config is not None, "Cannot get io_config for the model."
io_config = validate_config(io_config, IoConfig)
# If dynamic is False, set dynamic_axes and dynamic_shapes to None
if not config["dynamic"]:
io_config.dynamic_axes = None
io_config.dynamic_shapes = None

onnx_model = None
if config["use_dynamo_exporter"]:
@@ -213,11 +223,11 @@

dynamo_config.capture_scalar_outputs = True
if isinstance(dummy_inputs, dict):
dummy_kwargs = dummy_inputs
dummy_kwargs = the_input = dummy_inputs
dummy_inputs = ()
else:
dummy_kwargs = {}
dummy_inputs = tuple(dummy_inputs)
dummy_inputs = the_input = tuple(dummy_inputs)

if torch_version < dynamo_supported_version:
onnx_program = torch.onnx.dynamo_export(
@@ -228,19 +238,35 @@
)
onnx_model = onnx_program.model_proto
else:
onnx_program = torch.onnx.export( # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
pytorch_model,
dummy_inputs,
kwargs=dummy_kwargs,
opset_version=config["target_opset"],
input_names=io_config.input_names,
output_names=io_config.output_names,
dynamic_axes=io_config.dynamic_axes,
dynamo=True,
fallback=True,
)
assert onnx_program is not None
onnx_model = onnx_program.model_proto
# NOTE: Validation is usually done in io_config.py, but because
# dynamic_shapes is nested and cannot be validated repeatedly like the other fields,
# we validate it here.
io_config.dynamic_shapes = _validate_dynamic_shapes(io_config.dynamic_shapes, the_input)
io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes)

# there might be multiple files created during export, so we need to track the dir
# if there are other processes writing to the same dir, we might end up deleting files created by
# other processes
with tempfile.TemporaryDirectory(dir=tempdir, prefix="olive_tmp") as tmp_dir:
tmp_dir_path = Path(tmp_dir)
tmp_model_path = resolve_onnx_path(tmp_dir_path)

onnx_program = torch.onnx.export( # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
pytorch_model,
dummy_inputs,
tmp_model_path, # needed for fallback=True
kwargs=dummy_kwargs,
opset_version=config["target_opset"],
input_names=io_config.input_names,
output_names=io_config.output_names,
dynamic_axes=io_config.dynamic_axes,
dynamic_shapes=io_config.dynamic_shapes,
dynamo=True,
fallback=True,
report=logger.isEnabledFor(logging.DEBUG),
)
assert onnx_program is not None
onnx_model = onnx_program.model_proto
else:
# there might be multiple files created during export, so we need to track the dir
# if there are other processes writing to the same dir, we might end up deleting files created by
@@ -371,7 +397,6 @@ def _convert_model_on_device(
# get dummy inputs
dummy_inputs = self._get_dummy_inputs(model, config)
io_config = model.io_config

converted_onnx_model = OnnxConversion._export_pytorch_model(
pytorch_model, dummy_inputs, io_config, config, device, torch_dtype, tempfile.tempdir
)
@@ -549,3 +574,109 @@ def _run_for_config(
converted_model_proto, str(Path(model.model_path).resolve().parent)
)
return model_proto_to_olive_model(converted_model_proto, output_model_path, config)


def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):
"""Validate dynamic_shapes.
This function validates two things:
(1) To have a valid format of dynamic_shapes, we need to make sure the axes are converted to int.
It was string in the JSON format.
(2) To make sure the dynamic_shapes is in the same tree structure as dummy_inputs.
:param dynamic_shapes: the dynamic_shapes to validate
:param dummy_inputs: the dummy_inputs to align the dynamic_shapes format
:return: the validated dynamic_shapes
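
Example (illustrative):
    dynamic_shapes = {"input_ids": {"0": ["batch", 1, 8]}}
    dummy_inputs = {"input_ids": torch.randint(0, 10, (2, 16))}
    returns {"input_ids": {0: ["batch", 1, 8]}}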
"""
if not dynamic_shapes:
return dynamic_shapes

from torch.utils import _pytree

def is_dict_axes(x) -> bool:
return isinstance(x, dict) and all(
isinstance(key, str)
and len(key) == 1
and isinstance(value, list)
and len(value) == 3
and isinstance(value[0], str)
and isinstance(value[1], int)
and isinstance(value[2], int)
for key, value in x.items()
)

flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes)
new_dynamic_shapes = []
for axes in flat_dynamic_shapes:
if axes is None:
new_dynamic_shapes.append(axes)
continue
new_axes = {}
for axis, dynamic_shape in axes.items():
new_axes[int(axis)] = dynamic_shape
new_dynamic_shapes.append(new_axes)

_, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes)
return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)


def _convert_dynamic_shapes_to_torch_export_dims(
dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]]
) -> Dict[str, Dict[int, torch.export.Dim]]:
"""Convert dynamic_shapes to torch export dims.
torch.onnx.export takes the exported program (fx graph) from
torch.export.export, which requires the dynamic_shapes to be in the format
of using torch.export.Dim(name, min=min, max=max). This function converts
the dynamic_shapes to the format that torch.export.export requires.
For a single axis:
before: ["axis_name", min_value, max_value]
after: torch.export.Dim("axis_name", min=min_value, max=max_value)
# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
:param dynamic_shapes: the dynamic_shapes to convert
:return: the converted dynamic_shapes
"""
if dynamic_shapes is None:
return None

# Axes with the same name should map to the same torch.export.Dim
torch_export_dim_farm: Dict[str, torch.export.Dim] = {}

# dynamic_shapes follows input format, which could be nested
def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]:
if isinstance(data, dict):
for key, value in data.items():
data[key] = _from_tuple_to_dim(value)
# TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format?
# JSON format does not accept tuples.
elif isinstance(data, (tuple, list)):
# We assume the tuple/list is in the format of (name, min, max)
# TODO(titaiwang): This format could potentially be used as model
# inputs (would string be used as model input?)
if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int):
if data[0] in torch_export_dim_farm:
if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]:
return torch_export_dim_farm[data[0]]
raise ValueError(
f"Found different boundary for the same axis name {data[0]}. "
f"Previous min: {torch_export_dim_farm[data[0]].min} and "
f"max: {torch_export_dim_farm[data[0]].max}. "
f"Current min: {data[1]} and max: {data[2]}."
)
dim = torch.export.Dim(data[0], min=data[1], max=data[2])
torch_export_dim_farm[data[0]] = dim
return dim
if isinstance(data, tuple):
return tuple(_from_tuple_to_dim(item) for item in data)
if isinstance(data, list):
return [_from_tuple_to_dim(item) for item in data]
return data

return _from_tuple_to_dim(dynamic_shapes)
13 changes: 13 additions & 0 deletions olive/passes/pytorch/common.py
@@ -51,6 +51,18 @@ def inherit_pytorch_from_hf(
hf_io_config = deepcopy(model.io_config)
hf_dummy_inputs = model.get_dummy_inputs()

dynamic_shapes = hf_io_config.get("dynamic_shapes", {})
if isinstance(dynamic_shapes, dict):
dynamic_shapes = {
k: v
for k, v in dynamic_shapes.items()
if not k.startswith(("present", "past_key_values"))
}
else:
# TODO(titaiwang): fix this when we have a better way to handle dynamic_shapes
# If the dynamic_shapes is a list, we don't inherit it since
# we do not know the exact index of the past_key_values in the list
dynamic_shapes = {}
# kv cache will be handled by the kv_cache flag in io_config
io_config = {
"input_names": [i for i in hf_io_config.get("input_names", []) if not i.startswith("past_key_values")],
@@ -62,6 +74,7 @@
for k, v in hf_io_config.get("dynamic_axes", {}).items()
if not k.startswith(("present", "past_key_values"))
},
"dynamic_shapes": dynamic_shapes,
}

for i_name in io_config["input_names"]:
2 changes: 2 additions & 0 deletions test/requirements-test.txt
@@ -22,6 +22,8 @@ onnxconverter_common
onnxmltools
onnxoptimizer
onnxruntime_extensions
# TODO(titaiwang): Add onnxscript to requirements.txt once it's released
onnxscript
openvino==2023.2.0
optimum>=1.17.0
pandas
5 changes: 5 additions & 0 deletions test/unit_test/model/test_hf_model.py
@@ -144,6 +144,11 @@ def setup(self):
"attention_mask": {"0": "batch_size", "1": "seq_length"},
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
"dynamic_shapes": {
"input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
},
}

def test_dummy_input_with_kv_cache(self):
5 changes: 5 additions & 0 deletions test/unit_test/model/test_pytorch_model.py
@@ -22,6 +22,11 @@ def io_config_fixture():
"attention_mask": {"0": "batch_size", "1": "seq_length"},
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
"dynamic_shapes": {
"input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
},
}

