[MO|nGraph]GatherND_8 (#7743)
* Add GatherND_8 operation

* Update shape infer function and tests

* Initial commit for nGraph GatherND_8 operation

* Add GatherNDBase class implementation

* Fix base class errors

* Add missed header

* Update base class

* Update GatherND_8 implementation

* Fix codestyle

* Fix wrong rank

* Implement tests for gatherND_8 shape inference function

* Fix codestyle

* Add limitation to doc

* Simplify check in shape inference

* Add more test cases

* Update shape inference function

* Add more test cases to cover all cases with dynamic input shapes

* Update shape inference function

* Refactor tests

* Add visitor tests for gatherND_8 operation

* Correct comment

* Add additional check in shape inference function

* Update shape inference implementation for gathernd operation

* Fix codestyle

* Remove restriction that data must be fully defined

* Update shape inference function

* Fix missed check for NoneType

* Remove redundant checks for batch_dims

* Fix codestyle
Anton Chetverikov authored Nov 10, 2021
1 parent 76994c6 commit c8e1c8e
Showing 12 changed files with 643 additions and 138 deletions.
2 changes: 1 addition & 1 deletion docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
@@ -525,7 +525,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
| GRU | |
| Gather | |
| GatherElements | Doesn't work with negative indices |
- | GatherND | |
+ | GatherND | Doesn't work with negative indices |
| GatherTree | |
| Gemm | |
| GlobalAveragePool | |
78 changes: 46 additions & 32 deletions model-optimizer/extensions/ops/gathernd.py
@@ -16,7 +16,7 @@ def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': self.op,
            'op': self.op,
-           'version': 'opset5',
+           'version': 'opset8',
            'infer': self.infer,
            'in_ports_count': 2,
            'out_ports_count': 1,
@@ -56,41 +56,55 @@ def infer(node: Node):
        assert len(indices_shape) > 0, "Indices must not be a scalar"
        assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
            "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"
+       assert node['version'] in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \
+                                                       'name : {}'.format(node['version'], node.soft_get('name'))

        # compute output shape
+       batch = []
        if batch_dims > 0:
-           if is_fully_defined(data_shape[:batch_dims]):
-               batch = [np.prod(data_shape[:batch_dims]).tolist()]
-           else:
-               batch = [dynamic_dimension_value]
-       else:
-           batch = []
+           if node['version'] == 'opset5':  # Support old version of gatherND shape inference
+               if is_fully_defined(data_shape[:batch_dims]):
+                   batch = [np.prod(data_shape[:batch_dims]).tolist()]
+               else:
+                   batch = [dynamic_dimension_value]
+           elif node['version'] == 'opset8':
+               for dim in range(batch_dims):
+                   assert compatible_dims(indices_shape[dim], data_shape[dim]),\
+                       "Batch dimensions in data.shape and indices.shape must be compatible"
+               if is_fully_defined(indices_shape[:batch_dims]):
+                   batch = indices_shape[:batch_dims].tolist()
+               elif is_fully_defined(data_shape[:batch_dims]):
+                   batch = data_shape[:batch_dims].tolist()
+               else:
+                   for ind in range(batch_dims):
+                       if indices_shape[ind] != dynamic_dimension_value:
+                           batch.append(indices_shape[ind])
+                       elif data_shape[ind] != dynamic_dimension_value:
+                           batch.append(data_shape[ind])
+                       else:
+                           batch.append(dynamic_dimension_value)

        slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
-       output_shape = batch + list(indices_shape[batch_dims:-1]) + slice_shape

+       output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape
        node.out_port(0).data.set_shape(output_shape)

-       # compute output value if all input values are defined
-       if is_fully_defined(indices_value) and is_fully_defined(data_value):
-           output_value = np.zeros(output_shape, dtype=data_value.dtype)
-           if batch_dims == 0:
-               output_indices_range = int64_array(indices_shape[:-1])
-               for output_index in np.ndindex(tuple(output_indices_range)):
-                   indices_tuple = indices_value[output_index]
-                   output_value[output_index] = data_value[tuple(indices_tuple.T)]
-           else:
-               batch_dims_range = int64_array(indices_shape[:batch_dims])
-               for batch_indices in np.ndindex(tuple(batch_dims_range)):
-                   # compute batch index in output tensor
-                   batch_ind = 0
-                   num_elements = 1
-                   for ind in reversed(range(len(batch_dims_range))):
-                       batch_ind += batch_indices[ind] * num_elements
-                       num_elements *= batch_dims_range[ind]
-                   output_indices_range = int64_array(indices_shape[batch_dims:-1])
-                   for output_index in np.ndindex(tuple(output_indices_range)):
-                       tmp_ind = batch_indices + output_index
-                       indices_tuple = tuple(indices_value[tmp_ind].T)
-                       full_input_ind = batch_indices + indices_tuple
-                       full_output_ind = tuple(np.array([batch_ind]).T) + output_index
-                       output_value[full_output_ind] = data_value[full_input_ind]
+       # compute output value if all input indices are defined
+       if is_fully_defined(indices_value) and data_value is not None:
+           batch_dims_size = 1

+           for i in range(batch_dims):
+               batch_dims_size *= indices_shape[i]

+           output_data = []

+           reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1])

+           reshaped_data = data_value.reshape((batch_dims_size,) + tuple((data_shape[batch_dims:])))

+           for batch_dim in range(reshaped_indices.shape[0]):
+               for outer_dim in range(reshaped_indices.shape[1]):
+                   gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
+                   output_data.append(reshaped_data[(batch_dim,) + gather_index])
+           output_value = np.asarray(output_data, dtype=data_value.dtype).reshape(output_shape)
            node.out_port(0).data.set_value(output_value)
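A worked illustration of the opset8 shape-inference rule added above. This is a minimal sketch, not part of the commit, and the shapes and batch_dims value below are invented: opset8 GatherND keeps the leading batch dimensions of indices, while opset5 collapses them into a single flattened dimension.

def gather_nd_8_output_shape(data_shape, indices_shape, batch_dims):
    # opset8 rule: keep the batch dimensions instead of flattening them
    batch = list(indices_shape[:batch_dims])
    slice_shape = list(data_shape[batch_dims + indices_shape[-1]:])
    return batch + list(indices_shape[batch_dims:-1]) + slice_shape

# data_shape [2, 3, 4, 5], indices_shape [2, 3, 3, 2], batch_dims = 2:
# each index tuple of length 2 addresses the trailing [4, 5] part of data,
# so every gathered element is a scalar and the output shape is [2, 3, 3]
print(gather_nd_8_output_shape([2, 3, 4, 5], [2, 3, 3, 2], batch_dims=2))  # [2, 3, 3]
# the opset5 rule would instead flatten the two batch dimensions: [6, 3]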
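The new value computation can likewise be read as a plain NumPy gather over inputs whose batch dimensions have been flattened. Below is a standalone sketch of that idea; the data, indices and batch_dims values are invented for illustration and are not taken from the commit.

import numpy as np

data = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # data_shape = [2, 3, 4]
indices = np.array([[[2], [0]],
                    [[1], [2]]])               # indices_shape = [2, 2, 1]
batch_dims = 1

# flatten the batch dimensions of both inputs, as the updated infer function does
batch_dims_size = int(np.prod(indices.shape[:batch_dims]))                   # 2
reshaped_indices = indices.reshape(batch_dims_size, -1, indices.shape[-1])   # (2, 2, 1)
reshaped_data = data.reshape((batch_dims_size,) + data.shape[batch_dims:])   # (2, 3, 4)

# gather one slice of data per index tuple, batch by batch
output_data = []
for batch_dim in range(reshaped_indices.shape[0]):
    for outer_dim in range(reshaped_indices.shape[1]):
        gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
        output_data.append(reshaped_data[(batch_dim,) + gather_index])

# output shape = batch [2] + indices_shape[1:-1] [2] + slice shape [4] -> (2, 2, 4)
output = np.asarray(output_data, dtype=data.dtype).reshape(2, 2, 4)
print(output.shape)  # (2, 2, 4)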
