-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI] Average Pool2D Bug. #3607
Conversation
As #3581 discussed, TFLite's tvm.sum type is UINT16 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/optimized/optimized_ops.h#L2976), not int32. Could we consider this corner case? @tqchen 's said, we could compose this op to handle it, if so, could we add one test to cover? |
@FrozenGene Upcasting to uint16 as done in TFLite is little out of context in this PR. For now, I have added one more test case checking Let me know your comments on the CPP code changes. I am not very familiar with this section of code. Lets try to get this in quickly because this is a functional bug which will certainly affect automatic quantization accuracy as well. |
Array<Expr> indices; | ||
for (const Var& var : output) indices.push_back(var); | ||
indices.Set(height_axis, output[height_axis] * stride_height + dheight); | ||
indices.Set(width_axis, output[width_axis] * stride_width + dwidth); | ||
return tvm::sum(temp(indices) / divide_factor, { dheight, dwidth }); | ||
}; | ||
return tvm::sum(temp(indices), { dheight, dwidth }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we could leave one comment of TODO here, it is better. Remind us there is one potential issue we should care in the future parsing TFLite model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@FrozenGene If you mean int8<->uint8 casting for tflite, I think the TODO can be put into tflite frontend instead. The bugfix here in topi is complete IMO. what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean the tvm.sum
type should be uint16
for tflite's average_pool2d, not int32
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For TFLite adaptation, we can cast
the input to uint16
and then pass the casted input to avg_pool2d
. This is purely the job of framework parser. It does not need any code in the TOPI layer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes my understanding is avg_pool2d(int_8_value.astype("uint16")).astype("int8")
will be in frontend, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the way you want to do. Even the output dtype is "int32", you also want to do avg_pool2d(int_8_value.astype("int32")).astype("int8")
in frontend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If so, the way is ok.
Thanks @FrozenGene @anijain2305 @yzhliu . |
* [TOPI] Average Pool2D Bug. Issue - apache#3581 * Add uint16 test.
* [TOPI] Average Pool2D Bug. Issue - apache#3581 * Add uint16 test.
* [TOPI] Average Pool2D Bug. Issue - apache#3581 * Add uint16 test.
Issue - #3581
@yzhliu @tqchen @FrozenGene @shoubhik @rankyung