diff --git a/tools/onnx2bnn/OnnxConverter.cpp b/tools/onnx2bnn/OnnxConverter.cpp
index 51989eb..16c7fb8 100644
--- a/tools/onnx2bnn/OnnxConverter.cpp
+++ b/tools/onnx2bnn/OnnxConverter.cpp
@@ -172,9 +172,10 @@ std::vector<std::string> OnnxConverter::split(
     return outputs;
 }
 
-std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
-                                                const std::string &filepath,
-                                                const OnnxConverter::Level level) {
+std::vector<std::string> OnnxConverter::Convert(
+    const ONNX_NAMESPACE::ModelProto &model_proto, const std::string &filepath,
+    const OnnxConverter::Level level,
+    const std::vector<std::string> &expected_binary_conv_outputs) {
     GOOGLE_PROTOBUF_VERIFY_VERSION;
 
     // We recognize binary convolutions in our custom ONNX optimizers.
@@ -271,7 +272,12 @@ std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto
         }
 
         auto ori_weight_name = m(node.input(1));
-        const bool binary_conv = (node.domain() == "dabnn");
+        const bool binary_conv =
+            (node.domain() == "dabnn") ||
+            (std::find(expected_binary_conv_outputs.begin(),
+                       expected_binary_conv_outputs.end(),
+                       node.output(0)) !=
+             expected_binary_conv_outputs.end());
         if (binary_conv) {
             binary_conv_outputs.push_back(node.output(0));
         }
@@ -476,6 +482,17 @@ std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto
             throw std::invalid_argument("Unsupported operator " + op);
         }
     }
+
+    for (const auto &expected : expected_binary_conv_outputs) {
+        if (std::find(binary_conv_outputs.begin(), binary_conv_outputs.end(),
+                      expected) == binary_conv_outputs.end()) {
+            throw std::invalid_argument(
+                expected +
+                " is in the list file but not in the ONNX model, please check "
+                "your list file");
+        }
+    }
+
     auto flat_layers = builder_.CreateVector(layers_);
     auto flat_inputs = builder_.CreateVector(inputs);
     auto flat_tensors = builder_.CreateVector(tensors_);
diff --git a/tools/onnx2bnn/OnnxConverter.h b/tools/onnx2bnn/OnnxConverter.h
index 68db990..a6b61e9 100644
--- a/tools/onnx2bnn/OnnxConverter.h
+++ b/tools/onnx2bnn/OnnxConverter.h
@@ -155,7 +155,7 @@ class OnnxConverter {
     };
     std::vector<std::string> Convert(const ONNX_NAMESPACE::ModelProto &model,
                                      const std::string &filepath,
-                                     const Level level=Level::kModerate);
+                                     const Level level, const std::vector<std::string> &expected_binary_conv_outputs);
 };
 
 template <>
diff --git a/tools/onnx2bnn/onnx2bnn.cpp b/tools/onnx2bnn/onnx2bnn.cpp
index 1f9e9ad..1afc82b 100644
--- a/tools/onnx2bnn/onnx2bnn.cpp
+++ b/tools/onnx2bnn/onnx2bnn.cpp
@@ -19,7 +19,7 @@ void usage(const std::string &filename) {
     std::cout << "Usage:" << std::endl;
     std::cout << "  " << filename
               << " onnx_model output_filename [ --strict | --moderate | "
-                 "--aggressive ] [--verbose]"
+                 "--aggressive ] [--binary-list] [--verbose]"
              << std::endl;
     std::cout << std::endl;
     std::cout << "Options:" << std::endl;
@@ -41,6 +41,11 @@ void usage(const std::string &filename) {
            "A Conv operator, whose input is got from a Sign op and a Pad op "
            "(the order doesn't matter), and weight is got from a Sign op."
         << std::endl;
+    std::cout
+        << "  --binary-list    A text file containing the **output "
+           "names** of some convolutions, which will be treated as binary "
+           "convolutions unconditionally. It is mainly for benchmark purposes."
+        << std::endl;
     std::cout << std::endl;
     std::cout << "Example:" << std::endl;
     std::cout << "  " << filename
@@ -55,6 +60,7 @@
 
 int main(int argc, char **argv) {
     argh::parser cmdl;
+    cmdl.add_param("--binary-list");
     cmdl.parse(argc, argv);
     google::InitGoogleLogging(cmdl[0].c_str());
     FLAGS_alsologtostderr = true;
@@ -85,6 +91,18 @@ int main(int argc, char **argv) {
         FLAGS_v = 5;
     }
 
+    const auto binary_list_filepath = cmdl("binary-list").str();
+    vector<string> expected_binary_conv_outputs;
+    if (!binary_list_filepath.empty()) {
+        std::ifstream ifs(binary_list_filepath);
+        if (ifs.is_open()) {
+            string binary_conv_output;
+            while (ifs >> binary_conv_output) {
+                expected_binary_conv_outputs.push_back(binary_conv_output);
+            }
+        }
+    }
+
     ONNX_NAMESPACE::ModelProto model_proto;
     {
         std::ifstream ifs(cmdl[1], std::ios::in | std::ios::binary);
@@ -93,8 +111,8 @@ int main(int argc, char **argv) {
     }
 
     bnn::OnnxConverter converter;
-    const auto binary_conv_outputs =
-        converter.Convert(model_proto, cmdl[2], opt_level);
+    const auto binary_conv_outputs = converter.Convert(
+        model_proto, cmdl[2], opt_level, expected_binary_conv_outputs);
     LOG(INFO) << "Conversion completed! Found " << binary_conv_outputs.size()
               << " binary convolutions. Add --verbose to get what they are.";
 
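
Usage sketch (the file and tensor names below are illustrative, not taken
from the diff): the list file is read with "ifs >> binary_conv_output", so it
is plain whitespace-separated text, one convolution output name per token.
A hypothetical binary_list.txt might contain

    conv5_out
    conv7_out

and would be passed on the command line as

    onnx2bnn model.onnx model.dab --binary-list binary_list.txt

Every name in the list is unconditionally treated as a binary convolution;
a name that matches no Conv output in the model makes Convert() throw
std::invalid_argument, per the check added above.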