-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend][Pytorch]Add Pytorch advanced indexing #6318
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -274,16 +274,18 @@ def _impl(inputs, input_types): | |
end[dim] = min(end[dim], int(inputs[3])) | ||
else: | ||
if isinstance(inputs[3], _expr.Call): | ||
end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) | ||
target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) | ||
else: | ||
end[dim] = inputs[3] | ||
target_end = inputs[3] | ||
|
||
end[dim] = min(end[dim], target_end) | ||
|
||
strides.append(int(inputs[4])) | ||
return _op.transform.strided_slice(data, | ||
begin=_expr.const(begin), | ||
end=_expr.const(end), | ||
strides=_expr.const(strides), | ||
slice_mode="size") | ||
slice_mode="end") | ||
return _impl | ||
|
||
def _split(): | ||
|
@@ -1755,6 +1757,50 @@ def _impl(inputs, input_types): | |
return _impl | ||
|
||
|
||
def _index(): | ||
def _impl(inputs, input_types): | ||
data = inputs[0] | ||
indices = [] | ||
raw_indices = [] | ||
max_indices_len = -1 | ||
for index in inputs[1]: | ||
if not isinstance(index, _expr.Constant): | ||
try: | ||
index = _expr.const(_infer_value(index, {})) | ||
except Exception: | ||
raise RuntimeError("Only supports constant indices for " | ||
"pytorch advanced indexing ") | ||
raw_indices.append(index) | ||
cindex_len = index.data.shape[0] | ||
if cindex_len > max_indices_len: | ||
max_indices_len = cindex_len | ||
|
||
for index in raw_indices: | ||
cnp = index.data.asnumpy() | ||
cindex_len = cnp.shape[0] | ||
if cindex_len < max_indices_len: | ||
cnp = np.tile(cnp, max_indices_len // cindex_len) | ||
indices.append(cnp) | ||
|
||
ret = [] | ||
slice_map = {} | ||
for i in range(indices[0].shape[0]): | ||
tmp = data | ||
current_indices = [] | ||
for index in indices: | ||
current_indices.append(index[i]) | ||
index_key = tuple(current_indices) | ||
if index_key in slice_map: | ||
tmp = slice_map[index_key] | ||
else: | ||
tmp = _op.take(tmp, _expr.const(index[i]), axis=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. out of curiosity, would this introduce many take and tile ops? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In most cases it should be fine, but we can definitely improve it by adding a new topi implementation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am okay with this for now. We can implement a topi op if when perf is not good. |
||
slice_map[index_key] = tmp | ||
ret.append(_op.expand_dims(tmp, axis=0)) | ||
|
||
return _op.concatenate(ret, axis=0) | ||
return _impl | ||
|
||
|
||
def _meshgrid(): | ||
def _impl(inputs, input_types): | ||
data = inputs[0] | ||
|
@@ -2060,6 +2106,7 @@ def _get_convert_map(prelude): | |
"aten::type_as" : _type_as(), | ||
"aten::gather" : _gather(), | ||
"aten::index_select" : _select(), | ||
"aten::index" : _index(), | ||
} | ||
return convert_map | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems these two for loops are able to be merged
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_indices_len
needs to be fetched first.