Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #339 from aseemw/dev/add_const
Browse files Browse the repository at this point in the history
Added squeeze, unsqueeze, const conversions
  • Loading branch information
aseemw authored Sep 22, 2018
2 parents bd65864 + 0ac96db commit 0d5c37e
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 18 deletions.
2 changes: 2 additions & 0 deletions onnx_coreml/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def __init__(self,
# data blob name to the op_type that generates it
self.blob_from_op_type = {} # type: Dict[Text, Text]

self.constant_layers_added = {} # type: Dict[Text, bool]

for node_ in nodes:
for input_ in node_.inputs:
if input_ in self.blob_to_op_type:
Expand Down
52 changes: 39 additions & 13 deletions onnx_coreml/_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from coremltools.proto import NeuralNetwork_pb2 #type: ignore
from ._error_utils import ErrorHandling

_SEQUENCE_LAYERS_REGISTRY = set(["LSTM"])

def _compare(a, b, encoding="utf8"): #type: (Text, Text, Text) -> bool
if isinstance(a, bytes):
a = a.decode(encoding)
Expand Down Expand Up @@ -348,9 +346,6 @@ def _convert_add(builder, node, graph, err): # type: (NeuralNetworkBuilder, Nod
shape_bias=[second_input.shape[0]])
return

if 'broadcast' in node.attrs:
if node.attrs['broadcast'] == 1:
return err.unsupported_op_configuration(builder, node, graph, "Broadcast Add is not supported now")
builder.add_elementwise(
name=node.name,
input_names=node.inputs,
Expand All @@ -359,10 +354,6 @@ def _convert_add(builder, node, graph, err): # type: (NeuralNetworkBuilder, Nod
)

def _convert_mul(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
if 'broadcast' in node.attrs:
if node.attrs['broadcast'] == 1:
return err.unsupported_op_configuration(builder, node, graph, "Broadcast Multiply is not supported now")

builder.add_elementwise(
name=node.name,
input_names=node.inputs,
Expand All @@ -371,10 +362,6 @@ def _convert_mul(builder, node, graph, err): # type: (NeuralNetworkBuilder, Nod
)

def _convert_div(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
if 'broadcast' in node.attrs:
if node.attrs['broadcast'] == 1:
return err.unsupported_op_configuration(builder, node, graph, "Broadcast Div is not supported now")

builder.add_unary(name=node.name + '_inverse', #type: ignore
input_name=node.inputs[1],
output_name=node.inputs[1] + '_inverse',
Expand Down Expand Up @@ -985,6 +972,34 @@ def _convert_custom(builder, node, graph, err): # type: (NeuralNetworkBuilder, N

err.custom_layer_nodes.append(node)

def _convert_identity(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
builder.add_activation(
name=node.name,
non_linearity = 'LINEAR',
input_name=node.inputs[0],
output_name=node.outputs[0],
params=[1.0, 0.0]
)

def _convert_const(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None

for name, value in node.input_tensors.items():
if name not in graph.constant_layers_added:
shape = value.shape
coreml_shape = [1,1,1]
if len(shape) == 3:
coreml_shape = list(shape)
elif len(shape) == 1:
coreml_shape = [shape[0],1,1]
elif len(shape) == 2:
coreml_shape = [1, shape[0], shape[1]]
else:
return err.unsupported_op_configuration(builder, node, graph, "unable to translate constant array shape to CoreML shape")
builder.add_load_constant(name=name,
output_name=name,
constant_value=value.flatten(),
shape=coreml_shape)
graph.constant_layers_added[name] = True


_ONNX_NODE_REGISTRY = {
Expand Down Expand Up @@ -1050,8 +1065,13 @@ def _convert_custom(builder, node, graph, err): # type: (NeuralNetworkBuilder, N
"ArgMin": _convert_reduce,
"Clip": _convert_clip,
"MeanVarianceNormalization": _convert_mvn,
"Unsqueeze": _convert_identity,
"Squeeze": _convert_identity
}

_SEQUENCE_LAYERS_REGISTRY = set(["LSTM"])

_CONST_INPUT_ALLOWED_LAYERS = set([ "Add", "Sum", "Mul", "Concat", "Max", "Min", "Div", "Reciprocal"])

def _get_node_converter_fn(builder, node, err): # type: (NeuralNetworkBuilder, Node, ErrorHandling) -> Callable[[NeuralNetworkBuilder, Node, Graph, ErrorHandling], None]
"""
Expand All @@ -1063,6 +1083,12 @@ def _get_node_converter_fn(builder, node, err): # type: (NeuralNetworkBuilder,
else:
return err.unsupported_op(node)

def _add_const_inputs_if_required(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
if node.op_type in _CONST_INPUT_ALLOWED_LAYERS:
if len(node.input_tensors) > 0:
_convert_const(builder, node, graph, err)


def _convert_node(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
converter_fn = _get_node_converter_fn(builder, node, err)
return converter_fn(builder, node, graph, err)
10 changes: 8 additions & 2 deletions onnx_coreml/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Tuple

from ._operators import _convert_node, _SEQUENCE_LAYERS_REGISTRY, _ONNX_NODE_REGISTRY
from ._operators import _convert_node, _SEQUENCE_LAYERS_REGISTRY, _ONNX_NODE_REGISTRY, _add_const_inputs_if_required
from ._graph import Graph, EdgeInfo, Transformer
from ._transformers import ConvAddFuser, DropoutRemover, \
ReshapeInitTensorFuser, BNBroadcastedMulFuser, BNBroadcastedAddFuser, \
Expand Down Expand Up @@ -410,6 +410,7 @@ def convert(model, # type: Union[onnx.ModelProto, Text]

for i, node in enumerate(graph.nodes):
print("%d/%d: Converting Node Type %s" %(i+1, len(graph.nodes), node.op_type))
_add_const_inputs_if_required(builder, node, graph, err)
_convert_node(builder, node, graph, err)

if add_deprocess:
Expand Down Expand Up @@ -460,7 +461,12 @@ def convert(model, # type: Union[onnx.ModelProto, Text]
if outputs.name == output_:
builder.spec.description.output[i].shortDescription = 'This output is a sequence'

mlmodel = MLModel(builder.spec)
print("Translation to CoreML spec completed. Now compiling the CoreML model.")
try:
mlmodel = MLModel(builder.spec)
except:
raise ValueError('Compilation failed. Translation to CoreML spec was incorrect.')


# print information about all ops for which custom layers have been added
if len(err.custom_layer_nodes) > 0:
Expand Down
31 changes: 30 additions & 1 deletion tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _test_torch_model_single_io(torch_model, torch_input_shape, coreml_input_sha

# delete onnx model
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
shutil.rmtree(model_dir)

class OnnxModelTest(unittest.TestCase):

Expand Down Expand Up @@ -77,6 +77,35 @@ def forward(self, x):
torch_model.train(False)
_test_torch_model_single_io(torch_model, (1, 3, 100, 100), (3, 100, 100)) # type: ignore

def test_const_initializer1(self): # typr: () -> None
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.ones = torch.nn.Parameter(torch.ones(1,))

def forward(self, x):
y = x + self.ones
return y

torch_model = Net() # type: ignore
torch_model.train(False)
_test_torch_model_single_io(torch_model, (1, 3), (3,)) # type: ignore


def test_const_initializer2(self): # typr: () -> None
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()

def forward(self, x):
y = x + torch.nn.Parameter(torch.ones(2, 3))
return y

torch_model = Net() # type: ignore
torch_model.train(False)
_test_torch_model_single_io(torch_model, (1, 2, 3), (1, 2, 3)) # type: ignore



if __name__ == '__main__':
unittest.main()
Expand Down
2 changes: 0 additions & 2 deletions tests/onnx_backend_node_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,9 @@ def run_node(cls,
backend_test.exclude('test_log_softmax_lastdim_cpu')
backend_test.exclude('test_softmax_functional_dim3_cpu')
backend_test.exclude('test_softmax_lastdim_cpu')
backend_test.exclude('test_squeeze_cpu')
backend_test.exclude('test_sub_bcast_cpu')
backend_test.exclude('test_sub_cpu')
backend_test.exclude('test_sub_example_cpu')
backend_test.exclude('test_unsqueeze_cpu')
backend_test.exclude('test_slice_end_out_of_bounds_cpu')
backend_test.exclude('test_slice_neg_cpu')
backend_test.exclude('test_GLU_cpu')
Expand Down

0 comments on commit 0d5c37e

Please sign in to comment.