[NPU] add where_index op and tests #34951
Thanks for your contribution!
const auto& booled_runner =
    NpuOpRunner("Cast", {*condition}, {booled_cond},
                {{"dst_type", static_cast<int>(bool_type)}});
booled_runner.Run(stream);
According to WhereIndexOpMaker, AddInput("Condition", "A bool tensor whose rank is at least 1"); the condition here must be of bool type, so the Cast is not needed. Instead, add a PADDLE_ENFORCE_EQ to guarantee that the input is bool.
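Thanks. As discussed, the where_index op backs the upper-level APIs paddle.nonzero and paddle.fluid.layers.where, so it has to support inputs of several data types. The cast here ensures that, for non-bool inputs, the later count of true/non-zero values in the Tensor is correct.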
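A minimal sketch of why the cast matters, written in plain Python rather than NPU code. The helper names `count_by_raw_sum` and `count_after_bool_cast` are hypothetical; they only stand in for a ReduceSum without and with the preceding Cast.

```python
def count_by_raw_sum(values):
    # Summing raw values only counts correctly if every entry is 0 or 1.
    return sum(values)


def count_after_bool_cast(values):
    # Cast each entry to bool first, then sum the resulting 0/1 values.
    return sum(int(bool(v)) for v in values)


condition = [1, 0, 8]  # the non-bool input used in TestNotBool
print(count_by_raw_sum(condition))       # 9
print(count_after_bool_cast(condition))  # 2
```

Without the cast, the sum of `[1, 0, 8]` overstates the number of non-zero entries, so any output shape derived from it would be wrong.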
if (true_num == 0) {
  return;
}
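The logic above seems a bit complicated. In total it performs these steps:
- Cast BOOL -> INT64
- Sum the values with a reduce-sum
- Copy the sum from the NPU to the CPU
- Check whether the sum on the CPU is 0 and, if so, return immediately
Could this logic be dropped in favor of calling the Where OP directly? Or does calling Where directly fail? If it fails, you could try the NonZero operator.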
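Thanks. I tried calling the Where and NonZero OPs directly, and at present neither can complete the computation:
Both OPs have a dynamically shaped Output, but the Output's shape must be specified before the actual operator is invoked. The shape dimensions may be either exact values or reasonably large upper bounds. However, the current interaction mechanism cannot retrieve the Output's exact shape after the NPU operator has run, so the upper-bound approach cannot be used. Setting exact values requires a preliminary computation: a Cast to a type that ReduceSum supports, then copying the result back to the host, where the ReduceSum result is a single int64_t.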
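The count-then-allocate flow described above can be sketched in plain Python. This is an illustration under stated assumptions, not the NPU kernel: `where_index_sketch` is a hypothetical name, the input is a row-major flat list plus a shape tuple, and the first phase stands in for Cast + ReduceSum + copy to host.

```python
from itertools import product


def where_index_sketch(flat, shape):
    # Phase 1: count non-zero entries (stand-in for Cast + ReduceSum + D2H copy).
    true_num = sum(int(bool(v)) for v in flat)
    rank = len(shape)

    # Phase 2: allocate the output with its exact shape [true_num, rank].
    out = [[0] * rank for _ in range(true_num)]
    if true_num == 0:
        return out  # nothing to compute, mirrors the early return

    # Phase 3: fill in the coordinates of each non-zero entry, row-major order.
    row = 0
    for coords, v in zip(product(*(range(d) for d in shape)), flat):
        if v:
            out[row] = list(coords)
            row += 1
    return out


print(where_index_sketch([1, 0, 8], (3,)))                     # [[0], [2]]
print(where_index_sketch([True, False, False, True], (2, 2)))  # [[0, 0], [1, 1]]
```

The output size is known before the third phase starts, which is the property the exact-shape approach relies on.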
class TestNotBool(TestWhereIndexOp):
    def init_config(self):
        self.inputs = {'Condition': np.array([1, 0, 8]), }
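The unit test design here differs from Paddle's CPU/CUDA-side test, test_where_index.py. The Condition input here should only accept the BOOL data type. Please revise the unit test with reference to test_where_index.py.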
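Thanks. Our assessment is that the where_index op needs to support inputs of several data types. The CPU/GPU UT code does not include a corresponding check, so a non-bool UT is added here to verify that non-bool inputs are handled correctly.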
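As a rough illustration of what such a non-bool check expects, here is a standalone sketch using only the standard library. It does not use Paddle's OpTest; `ref_where_index_1d` and `TestNotBoolSketch` are hypothetical names, and the reference covers 1-D inputs only.

```python
import unittest


def ref_where_index_1d(values):
    # Hypothetical reference: one [index] row per non-zero entry.
    return [[i] for i, v in enumerate(values) if v]


class TestNotBoolSketch(unittest.TestCase):
    def test_int_input(self):
        # Same values as the Condition in TestNotBool.
        self.assertEqual(ref_where_index_1d([1, 0, 8]), [[0], [2]])

    def test_float_input(self):
        self.assertEqual(ref_where_index_1d([0.0, -1.5, 0.0]), [[1]])

    def test_all_zero_input(self):
        self.assertEqual(ref_where_index_1d([0, 0, 0]), [])
```

Saved to a file, the cases can be run with `python -m unittest`.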
LGTM
LGTM for shareDataWith
PR types
Others
PR changes
OPs
Describe
Add an NPU kernel for the where_index op, with tests.