Commit

Update docs
xiaoyu-work committed Dec 10, 2024
1 parent 27754e1 commit 46e7b10
Showing 5 changed files with 16 additions and 9 deletions.
15 changes: 11 additions & 4 deletions docs/source/how-to/configure-workflows/pass/convert-onnx.md
@@ -61,19 +61,26 @@ b. More fine-grained control of the conversion conditions is also possible:

See [Float16 Conversion](https://onnxruntime.ai/docs/performance/model-optimizations/float16.html#float16-conversion) for more detailed description of the available configuration parameters.

-## Inputs/Outputs Float16 to Float32 Conversion
+## Inputs/Outputs DataType Conversion

-Certain environments such as Onnxruntime WebGPU prefers Float32 logits. The `OnnxIOFloat16ToFloat32` pass converts the inputs and outputs to use Float32 instead of Float16.
+In certain environments, such as ONNX Runtime WebGPU, Float32 logits are preferred. The `OnnxIODataTypeConverter` pass converts model inputs and outputs to a specified data type. This is particularly useful for converting between data types such as Float16 and Float32, or any other supported ONNX data types.

### Example Configuration

-a. The most basic configuration, which is suitable for many models, leaves all configuration options set to their default values:
+The simplest configuration converts all inputs and outputs from Float16 (`source_dtype = 10`) to Float32 (`target_dtype = 1`), which is suitable for many models:

```json
 {
-    "type": "OnnxIOFloat16ToFloat32"
+    "type": "OnnxIODataTypeConverter",
+    "source_dtype": 10,
+    "target_dtype": 1
 }
```
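
Because both `source_dtype` and `target_dtype` are configurable, the pass can also run in the opposite direction. A hypothetical configuration converting Float32 inputs and outputs to Float16 (dtype codes from the ONNX enum) would look like:

```json
{
    "type": "OnnxIODataTypeConverter",
    "source_dtype": 1,
    "target_dtype": 10
}
```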

### Datatype Mapping

The `source_dtype` and `target_dtype` are integers corresponding to ONNX data types. You can find the complete mapping in the ONNX protobuf definition [here](https://github.com/onnx/onnx/blob/96a0ca4374d2198944ff882bd273e64222b59cb9/onnx/onnx.proto3#L503-L551).
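
For quick reference, the most common codes from that enum can be sketched as a plain mapping (values copied from `onnx.proto3`; not exhaustive):

```python
# Common ONNX TensorProto.DataType codes, as defined in onnx.proto3.
# These match the onnx package constants, e.g. onnx.TensorProto.FLOAT16 == 10.
ONNX_DTYPE = {
    "FLOAT": 1,      # float32
    "UINT8": 2,
    "INT8": 3,
    "INT32": 6,
    "INT64": 7,
    "BOOL": 9,
    "FLOAT16": 10,
    "DOUBLE": 11,
    "BFLOAT16": 16,
}

# The basic configuration above converts FLOAT16 -> FLOAT:
source_dtype = ONNX_DTYPE["FLOAT16"]  # 10
target_dtype = ONNX_DTYPE["FLOAT"]    # 1
```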

## Mixed Precision Conversion
Converts the model to mixed precision.
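
A minimal sketch of such a pass configuration, assuming Olive's `OrtMixedPrecision` pass (listed in the pass reference) with default options:

```json
{
    "type": "OrtMixedPrecision"
}
```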

4 changes: 2 additions & 2 deletions docs/source/reference/pass.rst
@@ -43,9 +43,9 @@ OnnxFloatToFloat16

.. _onnx_io_float16_to_float32:

-OnnxIOFloat16ToFloat32
+OnnxIODataTypeConverter
-----------------------
-.. autoconfigclass:: olive.passes.OnnxIOFloat16ToFloat32
+.. autoconfigclass:: olive.passes.OnnxIODataTypeConverter

.. _ort_mixed_precision:

2 changes: 1 addition & 1 deletion examples/phi2/phi2.py
@@ -156,7 +156,7 @@ def main(raw_args=None):
template_json["systems"]["local_system"]["accelerators"] = [
{"device": "GPU", "execution_providers": ["JsExecutionProvider"]}
]
-fl_type = {"type": "OnnxIOFloat16ToFloat32"}
+fl_type = {"type": "OnnxIODataTypeConverter"}
template_json["passes"]["fp32_logits"] = fl_type
new_json_file = "phi2_web.json"
with open(new_json_file, "w") as f:
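
The pattern in the snippet above can be sketched end to end: load or build a workflow template, register the IO conversion pass under a key, and write the modified config to a new file. The template contents here are illustrative stand-ins, not Olive's real phi2 template:

```python
import json

# Hypothetical minimal template; the real phi2 template has systems,
# engine settings, and more passes.
template_json = {"passes": {}}

# Register the IO data-type conversion pass (Float16 -> Float32 logits).
fl_type = {"type": "OnnxIODataTypeConverter", "source_dtype": 10, "target_dtype": 1}
template_json["passes"]["fp32_logits"] = fl_type

# Write the modified workflow config to a new file.
new_json_file = "phi2_web.json"
with open(new_json_file, "w") as f:
    json.dump(template_json, f, indent=4)
```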
2 changes: 1 addition & 1 deletion examples/phi3/phi3_template.json
@@ -95,7 +95,7 @@
"merge_adapter_weights": { "type": "MergeAdapterWeights" },
"awq": { "type": "AutoAWQQuantizer" },
"builder": { "type": "ModelBuilder", "precision": "<place_holder>" },
-"fp32_logits": { "type": "OnnxIOFloat16ToFloat32" },
+"fp32_logits": { "type": "OnnxIODataTypeConverter" },
"tune_session_params": {
"type": "OrtSessionParamsTuning",
"data_config": "gqa_transformer_prompt_dummy_data",
2 changes: 1 addition & 1 deletion olive/cli/auto_opt.py
@@ -440,7 +440,7 @@ def _get_passes_config(self, config: Dict[str, Any], olive_config: OlivePackageC
),
("peephole_optimizer", {"type": "OnnxPeepholeOptimizer"}),
# change io types to fp32
-("fp16_to_fp32", {"type": "OnnxIOFloat16ToFloat32"}),
+("fp16_to_fp32", {"type": "OnnxIODataTypeConverter"}),
# qnn preparation passes
("to_fixed_shape", {"type": "DynamicToFixedShape", "dim_param": None, "dim_value": None}),
("qnn_preprocess", {"type": "QNNPreprocess"}),
