
Commit

Fix onnx tests; need to pass scalar value (not np.array) to create_const_scalar_node.
Joe Evans committed Feb 9, 2021
1 parent b49a28f commit d0dfa91
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -162,7 +162,7 @@ def create_const_scalar_node(input_name, value, kwargs):
     initializer = kwargs["initializer"]
     input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
     value_node = make_tensor_value_info(input_name, input_type, ())
-    tensor_node = make_tensor(input_name, input_type, (), (value,))
+    tensor_node = make_tensor(input_name, input_type, (), ([value]))
     initializer.append(tensor_node)
     return value_node

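Note: create_const_scalar_node reads value.dtype to look up the ONNX tensor type and builds a rank-0 (shape ()) tensor, so it expects a single numpy scalar rather than a 1-element array. A minimal sketch of the distinction, in plain numpy (no ONNX needed):

import numpy as np

arr = np.array([0], dtype='float32')    # 1-d array, shape (1,): what call sites used to pass
scalar = np.float32(0)                  # numpy scalar, shape (): what the helper expects

print(arr.shape, arr.dtype)             # (1,) float32
print(np.shape(scalar), scalar.dtype)   # () float32, same dtype but a true scalar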
@@ -362,7 +362,7 @@ def convert_fully_connected(node, **kwargs):
     in_nodes = [name+'_data_flattened', input_nodes[1]]

     if no_bias:
-        nodes.append(create_const_scalar_node(name+'_bias', np.array([0], dtype=dtype), kwargs))
+        nodes.append(create_const_scalar_node(name+'_bias', np.int32(0).astype(dtype), kwargs))
         in_nodes.append(name+'_bias')
     else:
         in_nodes.append(input_nodes[2])
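Note: when no_bias is set, the translation still has to supply a bias input, so it creates a zero constant. np.int32(0).astype(dtype) casts to the layer's dtype while staying a numpy scalar, which is what the fixed helper requires. A quick check (the dtype name here is illustrative):

import numpy as np

dtype = 'float16'                  # illustrative target dtype
bias = np.int32(0).astype(dtype)   # the cast keeps the value a numpy scalar

print(type(bias), bias.dtype, np.shape(bias))
# <class 'numpy.float16'> float16 ()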
@@ -2430,7 +2430,7 @@ def convert_layer_norm(node, **kwargs):
         create_tensor([], name+"_void", kwargs["initializer"]),
         create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
         create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
-        create_const_scalar_node(name+"_2_s", np.array(2, dtype=dtype), kwargs),
+        create_const_scalar_node(name+"_2_s", np.int64(2).astype(dtype), kwargs),
         create_const_scalar_node(name+"_eps", np.float32(eps), kwargs),
         make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]),
         make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]),
@@ -2829,9 +2829,9 @@ def convert_arange_like(node, **kwargs):
         raise NotImplementedError("arange_like operator with repeat != 1 not yet implemented.")

     nodes = [
-        create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs),
-        create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs),
-        create_const_scalar_node(name+"_half_step", np.array([float(step)*0.5], dtype=dtype), kwargs),
+        create_const_scalar_node(name+"_start", np.float32(start).astype(dtype), kwargs),
+        create_const_scalar_node(name+"_step", np.float32(step).astype(dtype), kwargs),
+        create_const_scalar_node(name+"_half_step", np.float32(float(step)*0.5).astype(dtype), kwargs),
         create_tensor([], name+'_void', kwargs["initializer"])
     ]
     if axis == 'None':
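Note: the start/step/half_step scalars feed a Range-style subgraph that the hunk cuts off. One plausible role for the half-step constant, stated here as an assumption, is to guard the computed stop value against floating-point rounding at the boundary. Sketched in numpy:

import numpy as np

dtype = 'float32'                                    # illustrative
start = np.float32(1.0).astype(dtype)
step = np.float32(0.5).astype(dtype)
half_step = np.float32(float(step) * 0.5).astype(dtype)
n = 4                                                # would come from the like-tensor's shape

stop = start + np.float32(n).astype(dtype) * step    # 3.0
print(np.arange(start, stop - half_step, step))      # assumed guard: [1.  1.5 2.  2.5]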
@@ -2947,9 +2947,9 @@ def convert_arange(node, **kwargs):
         raise NotImplementedError("arange operator with repeat != 1 not yet implemented.")

     nodes = [
-        create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs),
-        create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs),
-        create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs),
+        create_const_scalar_node(name+"_start", np.float32(start).astype(dtype), kwargs),
+        create_const_scalar_node(name+"_stop", np.float32(stop).astype(dtype), kwargs),
+        create_const_scalar_node(name+"_step", np.float32(step).astype(dtype), kwargs),
         make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name], name=name)
     ]

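Note: convert_arange maps straight onto the ONNX Range op (opset 11 and later), whose start/limit/delta inputs the spec defines as scalars, which is exactly why 1-element arrays broke the exported graph. A standalone sketch built with onnx.helper directly (the tensor names are illustrative):

import numpy as np
from onnx import TensorProto, helper

# Rank-0 initializers, mirroring what create_const_scalar_node emits.
start = helper.make_tensor('r_start', TensorProto.FLOAT, (), [np.float32(0)])
stop = helper.make_tensor('r_stop', TensorProto.FLOAT, (), [np.float32(5)])
step = helper.make_tensor('r_step', TensorProto.FLOAT, (), [np.float32(1)])

node = helper.make_node('Range', ['r_start', 'r_stop', 'r_step'], ['r_out'])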
@@ -2977,7 +2977,7 @@ def convert_reverse(node, **kwargs):
         create_tensor([axis], name+'_axis', kwargs['initializer']),
         create_tensor([axis+1], name+'_axis_p1', kwargs['initializer']),
         create_tensor([], name+'_void', kwargs['initializer']),
-        create_const_scalar_node(name+'_m1_s', np.array([-1], dtype='int64'), kwargs),
+        create_const_scalar_node(name+'_m1_s', np.int64(-1), kwargs),
         make_node('Shape', [input_nodes[0]], [name+'_shape']),
         make_node('Shape', [name+'_shape'], [name+'_dim']),
         make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
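Note: the visible nodes compute the input's rank by taking the shape of the shape; the -1 scalar plausibly becomes the negative step of a slice that reverses the chosen axis (the rest of the chain is cut off by the hunk). The rank trick in numpy terms:

import numpy as np

x = np.zeros((3, 4, 5))
shape = np.array(x.shape, dtype=np.int64)      # Shape          -> [3 4 5]
rank = np.array(shape.shape, dtype=np.int64)   # Shape of Shape -> [3], the rank as a 1-d tensor
print(shape, rank)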
