[TF frontend] add some "Segment" and "UnsortedSegment" ops #6928
Nice algorithm, @alter-xp. I will review a bit more tomorrow. General questions:
- Could you add a few more comments to the generated TIR part?
- Would this also work for GPU?
python/tvm/topi/tensor.py
Outdated
@@ -73,3 +75,341 @@ def full_like(x, fill_value):
        The result.
    """
    return cpp.full_like(x, fill_value)


def segment_max(data, segment_ids, num_out):
Could you add a few more comments throughout this file? This would make it easier to read and more future-proof.
👌
python/tvm/topi/tensor.py
Outdated
    with ib.for_range(0, num_segment) as n:
        with ib.for_range(0, inner_size) as j:
            out_index = n * inner_size + j
            out[out_index] = -3.4028235e38
Is there a better way to do this? In theory you could also pre-compute the segment sizes, or calculate them on the fly. At the very least, I would use something like sys.float_info.min rather than a hard-coded literal.
sys.float_info.min is a number like 2.2250738585072014e-308, which is not suitable here because it is always greater than 0. I replaced it with float("inf").
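The point about sys.float_info.min can be checked directly. The snippet below is a minimal sketch in plain Python and NumPy, not the TIR code from this PR; `running_max` is a hypothetical helper. It assumes the sign of the infinity is chosen per reduction (negative infinity for max, positive infinity for min), which the reply above does not spell out.

```python
import sys
import numpy as np

# sys.float_info.min is the smallest positive normalized double,
# not the most negative representable value.
smallest_positive = sys.float_info.min

# Candidate starting values for a running max over float32 data.
neg_inf = float("-inf")
lowest_f32 = float(np.finfo("float32").min)  # close to the -3.4028235e38 literal in the diff


def running_max(values, init):
    """Fold max over values, starting from init (hypothetical helper)."""
    result = init
    for v in values:
        result = max(result, v)
    return result


all_negative = [-5.0, -2.5, -9.0]
# With a positive starting value, the starting value wins and the result is wrong.
wrong = running_max(all_negative, smallest_positive)
# With negative infinity, the true maximum of the data is returned.
right = running_max(all_negative, neg_inf)
```

Either `neg_inf` or `lowest_f32` gives the correct maximum for finite float32 inputs; the infinity avoids tying the code to one dtype's range.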
python/tvm/topi/tensor.py
Outdated
def _segment_min(data, segment_ids, out_buf):

    ib = tir.ir_builder.create()
    input_data = ib.buffer_ptr(data)
    seg_ids = ib.buffer_ptr(segment_ids)
    out = ib.buffer_ptr(out_buf)

    shape = get_const_tuple(data.shape)
    num_segment = get_const_tuple(out_buf.shape)[0]
    inner_size = 1
    for s in range(1, len(shape)):
        inner_size = inner_size * shape[s]

    with ib.for_range(0, num_segment) as n:
        with ib.for_range(0, inner_size) as j:
            out_index = n * inner_size + j
            out[out_index] = 3.4028235e38

        with ib.for_range(0, shape[0]) as k:
            with ib.if_scope(seg_ids[k] == n):
                with ib.for_range(0, inner_size) as l:
                    out_index = n * inner_size + l
                    in_index = k * inner_size + l
                    out[out_index] = te.min(input_data[in_index], out[out_index])

    return ib.get()
It looks like these functions share a common implementation. Could you write a single _segment_op with a string op parameter and add some logic inside this function?
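One way to picture the suggested refactoring is a table that maps the op name to a starting value and a combining function. The sketch below uses NumPy rather than the TIR builder, so it illustrates the idea only and is not the code that ended up in the PR; `segment_op_reference` and `_SEGMENT_OPS` are hypothetical names.

```python
import numpy as np

# Hypothetical dispatch table: op name -> (starting value, combining function).
_SEGMENT_OPS = {
    "sum": (0.0, np.add),
    "prod": (1.0, np.multiply),
    "max": (-np.inf, np.maximum),
    "min": (np.inf, np.minimum),
}


def segment_op_reference(data, segment_ids, num_segments, op):
    """Reduce the rows of `data` that share a segment id, using the named op."""
    init, combine = _SEGMENT_OPS[op]
    data = np.asarray(data, dtype="float32")
    # One output row per segment, filled with the op's starting value.
    out = np.full((num_segments,) + data.shape[1:], init, dtype="float32")
    for row, seg in zip(data, segment_ids):
        out[seg] = combine(out[seg], row)
    return out


data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])
ids = [0, 0, 1]
seg_sum = segment_op_reference(data, ids, 2, "sum")
seg_prod = segment_op_reference(data, ids, 2, "prod")
seg_max = segment_op_reference(data, ids, 2, "max")
seg_min = segment_op_reference(data, ids, 2, "min")
```

Only the starting value and the combining step differ between ops, which is why a single function with an op parameter can replace the near-identical copies.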
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👌
python/tvm/topi/tensor.py
Outdated
temp_index[num[0]] = k
num[0] += 1

with ib.if_scope(num[0] > 0):
So, if the segment is not present, we omit it? Why don't we do this for max and min? Some explanation of this code would be useful, I think.
Originally, the check was there to ensure that division by 0 could not occur during the calculation. It has now been changed so that it is only used for the mean.
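The guard only matters for the mean because that is the one op that divides by the segment size. A minimal NumPy sketch of the same idea follows; `segment_mean_reference` is a hypothetical name, and leaving empty segments at zero is an assumption of this sketch rather than something confirmed in the thread.

```python
import numpy as np


def segment_mean_reference(data, segment_ids, num_segments):
    """Mean of the rows in each segment; empty segments stay at zero (assumed)."""
    data = np.asarray(data, dtype="float32")
    segment_ids = np.asarray(segment_ids)
    out = np.zeros((num_segments,) + data.shape[1:], dtype="float32")
    for n in range(num_segments):
        rows = data[segment_ids == n]
        # Same role as the `num[0] > 0` check in the diff:
        # divide only when the segment has at least one member.
        if rows.shape[0] > 0:
            out[n] = rows.sum(axis=0) / rows.shape[0]
    return out


# Segment 1 has no members, so its row is never divided.
means = segment_mean_reference([[2.0, 4.0], [4.0, 8.0], [1.0, 1.0]], [0, 0, 2], 3)
```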
}


def verify_segmet(name, data_shape, segmnet_size):
Typo verify_segmet -> verify_segment
👌
segment_ids.append(segment_ids[-1])
return np.array(segment_ids).astype("int32")

def get_ref_data():
Instead of using tflite to get the reference data, could you add some reference data manually? That way the tests are independent and also show what each function is meant to do (you have already added reference tests in test_forward.py). What do you think?
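A test with hand-written reference data might look roughly like the sketch below. The expected arrays are worked out by hand from the input, and the NumPy `ufunc.at` calls are only a stand-in for the operator under test; in the real test the actual values would come from the compiled TVM function, which is not shown here.

```python
import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype="float32")
segment_ids = np.array([0, 0, 1, 2, 2], dtype="int32")

# Expected values written out by hand:
# segment 0 = {1, 2}, segment 1 = {3}, segment 2 = {4, 5}.
expected_sum = np.array([3.0, 3.0, 9.0], dtype="float32")
expected_max = np.array([2.0, 3.0, 5.0], dtype="float32")


def check(actual, expected):
    """Compare an operator's output against the hand-written reference."""
    np.testing.assert_allclose(actual, expected, rtol=1e-5)


# Stand-in for the operator under test (assumption: plain NumPy, not TVM).
actual_sum = np.zeros(3, dtype="float32")
np.add.at(actual_sum, segment_ids, data)
actual_max = np.full(3, -np.inf, dtype="float32")
np.maximum.at(actual_max, segment_ids, data)

check(actual_sum, expected_sum)
check(actual_max, expected_max)
```

Because the expected arrays are visible in the test itself, a reader can see what each op computes without running a second framework.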
Okay, thanks for all your comments. All suggestions have been revised.
Thanks @alter-xp for addressing the comments! LGTM
When testing unsorted_segment_mean, a single constant ends up as the input to the StridedSlice op in the TF graph, but the TF frontend in TVM does not support this case. This is fixed in PR #6949.
66839b4 to dac0351
@giuseros Hi, could you help me see what's wrong with this branch? It hasn't been merged.
* segment_max, segment_min, segment_mean, segment_sum, segment_prod
* unsorted_segment_max, unsorted_segment_min, unsorted_segment_mean
* unsorted_segment_prod, unsorted_segment_sum
This PR appears to be out of date; please feel free to reopen it if this is not the case. As part of the new year we are attempting to triage the project's open pull requests to ensure that code which Thanks again for your contribution, and feel free to reach out to discuss these changes.
@giuseros @siju-samuel