Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MO|nGraph]GatherND_8 #7743

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
d5df537
Add GatherND_8 operation
Sep 29, 2021
48e6176
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Sep 29, 2021
e7e9f19
Update shape infer function and tests
Oct 1, 2021
c07752e
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 1, 2021
aecfc65
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 5, 2021
c41f5ad
Initial commit for nGraph GatherND_8 operation
Oct 6, 2021
c50914f
Add GatherNDBase class implementation
Oct 8, 2021
212f72e
Fix base class errors
Oct 8, 2021
49ce072
Add missed header
Oct 8, 2021
f5e9ea6
Update base class
Oct 8, 2021
f134c95
Update GatherND_8 implementation
Oct 8, 2021
b36a565
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 8, 2021
58ac4cf
Fix codestyle
Oct 11, 2021
b001a3f
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 11, 2021
214095d
Fix wrong rank
Oct 11, 2021
a6b7d16
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 11, 2021
1152419
Implement tests for gatherND_8 shape inference function
Oct 11, 2021
7380d44
fix codestyle
Oct 11, 2021
cd85be1
Add limitation to doc
Oct 11, 2021
b78131a
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 12, 2021
6a0cbef
Simplify check in shape inference
Oct 15, 2021
afd02f8
Add more test cases
Oct 15, 2021
29314ec
Update shape inference function
Oct 15, 2021
5819af2
Add more test cases to cover all case with dynamic input shapes
Oct 15, 2021
cfcfa0e
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 15, 2021
c689aa6
Update shape inference function
Oct 18, 2021
b95d1eb
Refactor tests
Oct 18, 2021
1579088
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 18, 2021
20e605c
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 19, 2021
f5ca17d
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 20, 2021
76697d2
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 21, 2021
e803965
Add visitor tests for gatherND_8 operation
Oct 22, 2021
f0f834c
Correct comment
Oct 22, 2021
efe6c5b
Add additional check in shape inference function
Oct 25, 2021
3cd22ed
Update shape inference implementation for gathernd operation
Oct 25, 2021
c82bbf0
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 25, 2021
1e19448
Fix codestyle
Oct 25, 2021
a498aa4
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 25, 2021
9606771
Remove restriction for data is fully defined
Oct 26, 2021
dbb3059
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 26, 2021
157465e
Resolve merge conflict
Oct 27, 2021
b6d612e
Update shape inference function
Oct 27, 2021
af8df51
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 28, 2021
e7495b4
Fix missed check for nonetype
Oct 28, 2021
0180965
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Oct 28, 2021
5b51350
Remove redundant checks for batch_dims
Oct 28, 2021
9a8978b
Fix merge conflict
Nov 3, 2021
795e31b
Fix codestyle
Nov 3, 2021
25a2850
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Nov 8, 2021
ad55107
Merge remote-tracking branch 'upstream/master' into feature/achetver/…
Nov 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
| GRU | |
| Gather | |
| GatherElements | Doesn't work with negative indices |
| GatherND | |
| GatherND | Doesn't work with negative indices |
| GatherTree | |
| Gemm | |
| GlobalAveragePool | |
Expand Down
78 changes: 46 additions & 32 deletions model-optimizer/extensions/ops/gathernd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': self.op,
'op': self.op,
'version': 'opset5',
'version': 'opset8',
'infer': self.infer,
'in_ports_count': 2,
'out_ports_count': 1,
Expand Down Expand Up @@ -56,41 +56,55 @@ def infer(node: Node):
assert len(indices_shape) > 0, "Indices must not be a scalar"
assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
"Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"
assert node['version'] in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \
'name : {}'.format(node['version'], node.soft_get('name'))

# compute output shape
batch = []
if batch_dims > 0:
if is_fully_defined(data_shape[:batch_dims]):
batch = [np.prod(data_shape[:batch_dims]).tolist()]
else:
batch = [dynamic_dimension_value]
else:
batch = []
if node['version'] == 'opset5': # Support old version of gatherND shape inference
if is_fully_defined(data_shape[:batch_dims]):
batch = [np.prod(data_shape[:batch_dims]).tolist()]
achetver marked this conversation as resolved.
Show resolved Hide resolved
else:
batch = [dynamic_dimension_value]
elif node['version'] == 'opset8':
for dim in range(batch_dims):
assert compatible_dims(indices_shape[dim], data_shape[dim]),\
"Batch dimensions in data.shape and indices.shape must be compatible"
if is_fully_defined(indices_shape[:batch_dims]):
batch = indices_shape[:batch_dims].tolist()
elif is_fully_defined(data_shape[:batch_dims]):
batch = data_shape[:batch_dims].tolist()
else:
for ind in range(batch_dims):
if indices_shape[ind] != dynamic_dimension_value:
achetver marked this conversation as resolved.
Show resolved Hide resolved
batch.append(indices_shape[ind])
elif data_shape[ind] != dynamic_dimension_value:
batch.append(data_shape[ind])
else:
batch.append(dynamic_dimension_value)

slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
output_shape = batch + list(indices_shape[batch_dims:-1]) + slice_shape

output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape
node.out_port(0).data.set_shape(output_shape)

# compute output value if all input values are defined
if is_fully_defined(indices_value) and is_fully_defined(data_value):
output_value = np.zeros(output_shape, dtype=data_value.dtype)
if batch_dims == 0:
output_indices_range = int64_array(indices_shape[:-1])
for output_index in np.ndindex(tuple(output_indices_range)):
indices_tuple = indices_value[output_index]
output_value[output_index] = data_value[tuple(indices_tuple.T)]
else:
batch_dims_range = int64_array(indices_shape[:batch_dims])
for batch_indices in np.ndindex(tuple(batch_dims_range)):
# compute batch index in output tensor
batch_ind = 0
num_elements = 1
for ind in reversed(range(len(batch_dims_range))):
batch_ind += batch_indices[ind] * num_elements
num_elements *= batch_dims_range[ind]
output_indices_range = int64_array(indices_shape[batch_dims:-1])
for output_index in np.ndindex(tuple(output_indices_range)):
tmp_ind = batch_indices + output_index
indices_tuple = tuple(indices_value[tmp_ind].T)
full_input_ind = batch_indices + indices_tuple
full_output_ind = tuple(np.array([batch_ind]).T) + output_index
output_value[full_output_ind] = data_value[full_input_ind]
# compute output value if all input indices are defined
if is_fully_defined(indices_value) and data_value is not None:
batch_dims_size = 1

for i in range(batch_dims):
batch_dims_size *= indices_shape[i]

output_data = []
achetver marked this conversation as resolved.
Show resolved Hide resolved

reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1])

reshaped_data = data_value.reshape((batch_dims_size,) + tuple((data_shape[batch_dims:])))

for batch_dim in range(reshaped_indices.shape[0]):
for outer_dim in range(reshaped_indices.shape[1]):
gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
output_data.append(reshaped_data[(batch_dim,) + gather_index])
output_value = np.asarray(output_data, dtype=data_value.dtype).reshape(output_shape)
node.out_port(0).data.set_value(output_value)
Loading