Skip to content

Commit

Permalink
Add ScatterUpdate value infer (#12595)
Browse files Browse the repository at this point in the history
* Add ScatterUpdate value infer

* Add additional test case to ScatterUpdate tests
  • Loading branch information
mvafin authored Aug 22, 2022
1 parent 9710bde commit 56808c7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 15 deletions.
25 changes: 25 additions & 0 deletions tools/mo/openvino/tools/mo/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,28 @@ class ScatterSub(Scatter):
class ScatterUpdate(Scatter):
op = op_type = 'ScatterUpdate'
version = 'opset3'

@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)
Scatter.infer(node)

input_shape = node.in_port(0).data.get_shape()

input_value = node.in_port(0).data.get_value()
indices_value = node.in_port(1).data.get_value()
updates_value = node.in_port(2).data.get_value()

axis = node.in_port(3).data.get_value()

if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
node_name, axis.size)
axis = axis.item()
if axis < 0:
axis = len(input_shape) + axis

out_value = input_value.copy()
for idx in np.ndindex(*input_shape[:axis]):
out_value[idx][indices_value] = updates_value[idx]
node.out_port(0).data.set_value(out_value)
114 changes: 99 additions & 15 deletions tools/mo/unit_tests/mo/ops/scatter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from generator import generator, generate

from openvino.tools.mo.ops.scatter import ScatterElementsUpdate
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate, ScatterUpdate
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.graph.graph import Node
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data
Expand Down Expand Up @@ -40,20 +40,20 @@ class ScatterElementsInferTest(unittest.TestCase):
[[1.0, 1.1, 3.0, 2.1, 5.0]]),
([ # 3D case
[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]
],
[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]
],
[
[[1, 0],
[0, 1]],
[[1, 0],
[1, 0]],
[[0, 1],
[1, 0]]
[[1, 0],
[0, 1]],
[[1, 0],
[1, 0]],
[[0, 1],
[1, 0]]
],
[
[[21, 22],
Expand All @@ -73,7 +73,6 @@ class ScatterElementsInferTest(unittest.TestCase):
[32, 31]]
]),
])

def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res):
nodes = {
**valued_const_with_data('data', np.array(data)),
Expand Down Expand Up @@ -101,3 +100,88 @@ def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res

res_output_value = scatter_el_node.out_node().value
self.assertTrue(np.array_equal(ref_res, res_output_value))


@generator
class ScatterUpdateInferTest(unittest.TestCase):
@generate(*[
([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]],
[[1, 2]],
[[[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2]]],
0,
[[0.0, 0.0, 0.0],
[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2]]),
# negative axis
([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]],
[[1, 2]],
[[[1.0, 1.1]],
[[1.2, 2.0]],
[[2.1, 2.2]]],
-1,
[[0.0, 1.0, 1.1],
[0.0, 1.2, 2.0],
[0.0, 2.1, 2.2]]),
# one element
([[[0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.]]],
[[1]],
[[[[1., 2.], [3., 4.], [5., 6.]]]],
0,
[[[0., 0.], [0., 0.], [0., 0.]],
[[1., 2.], [3., 4.], [5., 6.]],
[[0., 0.], [0., 0.], [0., 0.]]]),
# shape [2,3,3]
([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
# indices [3,2]
[[1, 2], [0, 1], [1, 2]],
# updates [2,3,2,3]
[[[[1., 2., 3.], [4., 5., 6.]],
[[7., 8., 9.], [9., 8., 7.]],
[[6., 5., 4.], [3., 2., 1.]]],
[[[1., 2., 3.], [4., 5., 6.]],
[[7., 8., 9.], [9., 8., 7.]],
[[6., 5., 4.], [3., 2., 1.]]]],
# axis
1,
# ref
[[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]],
[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]]]),
])
def test_scatter_update_value_infer(self, data, indices, updates, axis, ref_res):
nodes = {
**valued_const_with_data('data', np.array(data)),
**valued_const_with_data('indices', int64_array(indices)),
**valued_const_with_data('updates', np.array(updates)),
**valued_const_with_data('axis', int64_array(axis)),
**regular_op_with_empty_data('scatter_update', {'op': 'ScatterUpdate', 'axis': axis}),
**result()
}

graph = build_graph(nodes_attrs=nodes, edges=[
*connect('data', '0:scatter_update'),
*connect('indices', '1:scatter_update'),
*connect('updates', '2:scatter_update'),
*connect('axis', '3:scatter_update'),
*connect('scatter_update', 'output')
], nodes_with_edges_only=True)
graph.stage = 'middle'

scatter_update_node = Node(graph, 'scatter_update')
ScatterUpdate.infer(scatter_update_node)

res_output_shape = scatter_update_node.out_node().shape
self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape))

res_output_value = scatter_update_node.out_node().value
self.assertTrue(np.array_equal(ref_res, res_output_value))

0 comments on commit 56808c7

Please sign in to comment.