-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
29 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import onnx | ||
from onnx import helper | ||
from printk import print_colored_box | ||
|
||
# 加载 ONNX 模型 | ||
model_path = '/mnt/share_disk/bruce_cui/onnx_models/laneline_154w_20240320_fastbev_onnxsim.onnx' | ||
model = onnx.load(model_path) | ||
|
||
# 假设我们要修改的输出是模型的最后一个输出 | ||
# 首先,获取输出的数量 | ||
num_outputs = len(model.graph.output) | ||
print(num_outputs) | ||
|
||
for output in model.graph.output: | ||
|
||
if output.type.tensor_type.elem_type == onnx.TensorProto.INT64: | ||
# 修改数据类型为 float32 | ||
output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT | ||
|
||
# 保存修改后的模型 | ||
modified_model_path = '/mnt/share_disk/bruce_cui/onnx_models/laneline_154w_20240320_fastbev_onnxsim_output_float32.onnx' | ||
onnx.save(model, modified_model_path) | ||
|
||
print_colored_box("模型输出类型已从 int64 修改为 float32。") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters