pnnx fix onnx clip conversion, add onnx clamp test, always reserve onnx split outputs, convert onnx mod (#5834)
nihui authored Dec 19, 2024
1 parent 24908b7 commit 78aca5d
Showing 6 changed files with 116 additions and 12 deletions.
33 changes: 30 additions & 3 deletions tools/pnnx/src/pass_level2/torch_clamp.cpp
@@ -47,7 +47,7 @@ class torch_clamp_onnx : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input input 0 1 input
-aten::clamp op_0 1 1 input out min=%min max=%max
+Clip op_0 1 1 input out min=%min max=%max
 pnnx.Output output 1 0 out
 )PNNXIR";
     }
@@ -68,7 +68,7 @@ class torch_clamp_onnx_1 : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input input 0 1 input
-aten::clamp op_0 1 1 input out max=%max
+Clip op_0 1 1 input out max=%max
 pnnx.Output output 1 0 out
 )PNNXIR";
     }
@@ -80,11 +80,38 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        op->params["min"] = 0.f;
+        op->params["min"] = Parameter();
         op->params["max"] = captured_params.at("max");
     }
 };
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx_1, 40)
 
+class torch_clamp_onnx_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+Clip op_0 1 1 input out min=%min
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.clamp";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["max"] = Parameter();
+        op->params["min"] = captured_params.at("min");
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx_2, 40)
+
 } // namespace pnnx
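The fix here is twofold: the max-only pattern now leaves the missing bound unset (`Parameter()`) instead of defaulting `min` to `0.f`, and a new `torch_clamp_onnx_2` pass handles the min-only case. A minimal sketch (not part of the commit) of why the old `0.f` default was wrong — `torch.clamp` with only `max` must not clip negative values:

import torch

x = torch.tensor([-3.0, -1.0, 0.5, 4.0])

# intended semantics: clamp the upper bound only
print(torch.clamp(x, max=2))         # tensor([-3.0000, -1.0000,  0.5000,  2.0000])

# what the old min=0.f default effectively produced
print(torch.clamp(x, min=0, max=2))  # tensor([0.0000, 0.0000, 0.5000, 2.0000])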
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_onnx.cpp
@@ -739,9 +739,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
         if (op_type == "And") sim_op_type = "aten::__and__";
         if (op_type == "Or") sim_op_type = "aten::__or__";
         if (op_type == "Xor") sim_op_type = "aten::__xor__";
+        if (op_type == "Mod" && onnx2pnnx::OnnxNodeProxy(node).attribute("fmod").value_i() == 1) sim_op_type = "aten::fmod";
 
         // trinaryop
-        if (op_type == "Clip") sim_op_type = "aten::clamp";
         if (op_type == "Where") sim_op_type = "aten::where";
     }
     else if (string_starts_with(op_type, "aten_"))
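The `Clip` line is removed here because `Clip` is now matched directly by the rewriter patterns in `torch_clamp.cpp` above. The `fmod == 1` guard on the new `Mod` line matters because ONNX `Mod` covers two semantics: with `fmod=1` the result takes the sign of the dividend (C `fmod`, matching `aten::fmod`), while `fmod=0` takes the sign of the divisor (Python `%`, matching `torch.remainder`). A quick sketch of the difference (standard torch behavior, not from this commit):

import torch

a = torch.tensor([-7.0, 7.0])
b = torch.tensor([3.0, -3.0])

# sign follows the dividend -> corresponds to ONNX Mod with fmod=1
print(torch.fmod(a, b))       # tensor([-1.,  1.])
# sign follows the divisor -> corresponds to ONNX Mod with fmod=0
print(torch.remainder(a, b))  # tensor([ 2., -2.])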
30 changes: 22 additions & 8 deletions tools/pnnx/src/pass_onnx/dead_code_elimination.cpp
@@ -44,16 +44,25 @@ static void collect_dead_nodes(const onnx::GraphProto& graph, std::vector<std::s
 
         if (is_outputs_live)
         {
-            for (int j = node.output_size() - 1; j >= 0; j--)
-            {
-                if (live_inputs.find(node.output(j)) == live_inputs.end())
-                {
-                    dead_outputs.push_back(node.output(j));
-                }
-                else
-                {
-                    // leading outputs cannot be optional
-                    break;
-                }
-            }
+            bool is_outputs_optional = true;
+
+            // some operator outputs cannot be discarded even if not used
+            const std::string& op_type = node.op_type();
+            if (op_type == "Split") is_outputs_optional = false;
+
+            if (is_outputs_optional)
+            {
+                for (int j = node.output_size() - 1; j >= 0; j--)
+                {
+                    if (live_inputs.find(node.output(j)) == live_inputs.end())
+                    {
+                        dead_outputs.push_back(node.output(j));
+                    }
+                    else
+                    {
+                        // leading outputs cannot be optional
+                        break;
+                    }
+                }
+            }

@@ -65,6 +74,11 @@ static void collect_dead_nodes(const onnx::GraphProto& graph, std::vector<std::s
         else
        {
             dead_node_indexes.push_back(i);
+
+            for (int j = node.output_size() - 1; j >= 0; j--)
+            {
+                dead_outputs.push_back(node.output(j));
+            }
         }
 
         if (is_outputs_live)
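Two behaviors change here. First, `Split` outputs are never treated as optional: an ONNX `Split` distributes its input across all declared outputs, so dropping trailing unused outputs would change the meaning of the node (in particular when the number of pieces is inferred from the output count). Second, when a whole node is dead, all of its outputs are now recorded in `dead_outputs` rather than being dropped implicitly. A hedged repro sketch for the `Split` case, assuming `torch.chunk` lowers to ONNX `Split` as it typically does:

import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        # exports as a single ONNX Split with three declared outputs,
        # even though only the first one is consumed downstream
        a, b, c = torch.chunk(x, chunks=3, dim=1)
        return a

torch.onnx.export(Model(), (torch.rand(1, 6, 4),), "split_unused.onnx")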
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
@@ -32,6 +32,8 @@ struct constant_as_attribute
 };
 
 static constant_as_attribute caas[] = {
+    {"Clip", 1, "min"},
+    {"Clip", 2, "max"},
     {"Expand", 1, "shape"},
     {"Gather", 1, "indices"},
     {"If", 0, "cond"},
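These table entries work because, since ONNX opset 11, `Clip` takes `min` and `max` as optional inputs at positions 1 and 2 rather than as attributes, and exporters typically feed them from `Constant` nodes; fusing them back into attributes is what lets the `Clip ... min=%min max=%max` patterns in `torch_clamp.cpp` above capture them. A hand-built sketch of such a fragment using `onnx.helper` (tensor and value names illustrative):

import onnx
from onnx import helper, TensorProto

# Clip-11+: inputs are (input, min?, max?) -- min at index 1, max at index 2
min_c = helper.make_node("Constant", [], ["min_v"],
                         value=helper.make_tensor("mv", TensorProto.FLOAT, [], [-1.0]))
max_c = helper.make_node("Constant", [], ["max_v"],
                         value=helper.make_tensor("xv", TensorProto.FLOAT, [], [1.0]))
clip = helper.make_node("Clip", ["x", "min_v", "max_v"], ["y"])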
1 change: 1 addition & 0 deletions tools/pnnx/tests/onnx/CMakeLists.txt
@@ -137,6 +137,7 @@ pnnx_onnx_add_test(Tensor_view)
 pnnx_onnx_add_test(torch_cat)
 pnnx_onnx_add_test(torch_ceil)
 pnnx_onnx_add_test(torch_chunk)
+pnnx_onnx_add_test(torch_clamp)
 pnnx_onnx_add_test(torch_flatten)
 pnnx_onnx_add_test(torch_floor)
 pnnx_onnx_add_test(torch_max)
60 changes: 60 additions & 0 deletions tools/pnnx/tests/onnx/test_torch_clamp.py
@@ -0,0 +1,60 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        x = torch.clamp(x, max=2)
        y = torch.clamp(y, min=0)
        z = torch.clamp(z, min=-1, max=1)
        return x, y, z

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 3, 16)
    y = torch.rand(1, 5, 9, 11)
    z = torch.rand(14, 8, 5, 9, 10)

    a = net(x, y, z)

    # export onnx
    torch.onnx.export(net, (x, y, z), "test_torch_clamp.onnx")

    # onnx to pnnx
    import os
    os.system("../../src/pnnx test_torch_clamp.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

    # pnnx inference
    import test_torch_clamp_pnnx
    b = test_torch_clamp_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True

if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)
