Skip to content

Commit

Permalink
Add expand (#79)
Browse files Browse the repository at this point in the history
* add expand

* fix expand op convert
  • Loading branch information
bzhang5 authored Jun 10, 2021
1 parent 90b7a9a commit 51ddcda
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tools/onnx/onnx_serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,9 @@ bool OnnxSerializer::LoadConstTensor(StaticGraph* graph, const onnx::GraphProto&
std::vector<std::string> tensor_check;
for (int i = 0; i < onnx_graph.node_size(); i++)
{

const onnx::NodeProto& onnx_node = onnx_graph.node(i);

for (int j = 0; j < onnx_node.input_size(); j++)
{
const std::string& input_name = onnx_node.input(j);
Expand Down Expand Up @@ -470,7 +471,6 @@ void OnnxSerializer::CreateInputNode(StaticGraph* graph, const onnx::GraphProto&
for (int i = 0; i < input_number; i++)
{
const onnx::ValueInfoProto& val = onnx_graph.input(i);

if (FindConstTensor(graph, val.name()) != nullptr)
{
continue;
Expand Down Expand Up @@ -1876,12 +1876,14 @@ static bool LoadOnnxExpand(StaticGraph* graph, StaticNode* node, const onnx::Nod

StaticTensor* shape_tensor = FindTensor(graph, onnx_node.input(1));
int size = shape_tensor->dims[0];
int64_t* data = ( int64_t* )GetConstTensorBuffer(shape_tensor);
param.dim_num = size;
for (int i = 0; i < size; i++)
{
param.shape.push_back(data[i]);
if(shape_tensor->mem_size != 0){
int64_t* data = ( int64_t* )GetConstTensorBuffer(shape_tensor);
for (int i = 0; i < size; i++)
{
param.shape.push_back(data[i]);
}
}
param.dim_num = size;

StaticOp* op = CreateStaticOp(graph, "Expand");
SetOperatorParam(op, param);
Expand Down

0 comments on commit 51ddcda

Please sign in to comment.