Skip to content

Commit

Permalink
[Feature] Add exporter pytest (#504)
Browse files Browse the repository at this point in the history
* add exporter pytest

* fix bugs

* delete useless codes

* handle onnx

* delete useless codes
  • Loading branch information
HIT-cwh authored Apr 12, 2023
1 parent 05da6f5 commit 7958ed8
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 34 deletions.
13 changes: 8 additions & 5 deletions mmrazor/models/quantizers/exporters/base_quantize_exporter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import onnx
from mmengine import print_log
from onnx import numpy_helper

from .optim_utils import ONNXOptimUtils

try:
import onnx
from onnx import numpy_helper
except ImportError:
from mmrazor.utils import get_package_placeholder
onnx = get_package_placeholder('No module named onnx')
numpy_helper = get_package_placeholder('No module named onnx.numpy_helper')

SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose']

PERCHANNEL_FAKEQUANTIZER = [
Expand Down Expand Up @@ -73,9 +79,6 @@ def _init_mappings_from_onnx(self, onnx_model):
self.output2node = self.optimizer.map_output_and_node(onnx_model)
self.name2data = self.optimizer.map_name_and_data(onnx_model)

# todo: maybe useless
# self.name2init = self.optimizer.map_name_and_initializer(onnx_model)

def _remap_input_and_node(self):
"""Rebuild the mapping from input name to a (node, input index)
tuple."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
from typing import List

import numpy as np
import onnx
from google.protobuf.internal.containers import RepeatedScalarFieldContainer
from onnx import helper, numpy_helper

try:
import onnx
from onnx import helper, numpy_helper
except ImportError:
from mmrazor.utils import get_package_placeholder
onnx = get_package_placeholder('No module named onnx')
numpy_helper = get_package_placeholder('No module named onnx.numpy_helper')
helper = get_package_placeholder('No module named onnx.helper')

from .base_quantize_exporter import BaseQuantizeExportor

Expand Down
36 changes: 10 additions & 26 deletions mmrazor/models/quantizers/exporters/optim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
import copy
from typing import Dict, List, Optional

import onnx
from mmengine import print_log
from onnx import numpy_helper

try:
import onnx
from onnx import numpy_helper
except ImportError:
from mmrazor.utils import get_package_placeholder
onnx = get_package_placeholder('No module named onnx')
numpy_helper = get_package_placeholder('No module named onnx.numpy_helper')


class ONNXOptimUtils():
Expand Down Expand Up @@ -62,30 +68,6 @@ def map_input_and_node(cls, onnx_model: onnx.ModelProto):
input2node[input_name].append([node, idx])
return input2node

@classmethod
def get_constant(cls, name, onnx_model):
for node in onnx_model.graph.node:
if node.op_type == 'Constant':
if node.output[0] == name:
return numpy_helper.to_array(node.attribute[0].t).tolist()

@classmethod
def get_initializer(cls, initializer_name, onnx_model):
return numpy_helper.to_array(
onnx_model.initializer[initializer_name][0])

@classmethod
def get_tensor_producer(cls, output_name, output2node):
if output_name not in output2node:
return 'INPUT_TOKEN'
return output2node[output_name]

@classmethod
def get_tensor_consumer(self, input_name, input2node):
if input_name not in input2node:
return ['OUTPUT_TOKEN']
return input2node[input_name]

@classmethod
def remove_node_from_onnx(cls, node: onnx.NodeProto,
onnx_model: onnx.ModelProto):
Expand Down Expand Up @@ -260,6 +242,8 @@ def topo_sort(cls,

@classmethod
def optimize(cls, onnx_model):
"""Remove standalone nodes and redundant initializers, and
topologically sort the nodes in a directed acyclic graph."""

input2node = cls.map_input_and_node(onnx_model)
output2node = cls.map_output_and_node(onnx_model)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import onnx

try:
import onnx
except ImportError:
from mmrazor.utils import get_package_placeholder
onnx = get_package_placeholder('No module named onnx')

from .base_quantize_exporter import BaseQuantizeExportor

Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ isort==4.3.21
nbconvert
nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx
pytest
xdoctest >= 0.10.0
yapf
Loading

0 comments on commit 7958ed8

Please sign in to comment.