-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Paddle Tensorrt] add pd_op.greater_than, pd_op.less_than marker and converter #68686
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
layer_output = add_elementwise_layer( | ||
network, paddle_op, inputs, trt.ElementWiseOperation.GREATER | ||
) | ||
return trt_cast(network, layer_output, trt.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不可能只支持float32类型吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
layer_output = add_elementwise_layer( | ||
network, paddle_op, inputs, trt.ElementWiseOperation.LESS | ||
) | ||
return trt_cast(network, layer_output, trt.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
|
||
@converter_registry.register("pd_op.greater_than", trt_version="8.x") | ||
def greater_than_converter(network, paddle_op, inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议Binray的Converter可以合并
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
def setUp(self): | ||
self.python_api = paddle.greater_than | ||
self.api_args = { | ||
"x": np.random.randn(3).astype(np.float32), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
整体没啥问题,但是单测是否覆盖更多的类型,例如INT、float16等类型会更好。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
if paddle_op.name() == "pd_op.greater_than": | ||
layer_output = add_elementwise_layer( | ||
network, paddle_op, inputs, trt.ElementWiseOperation.GREATER | ||
) | ||
else: | ||
layer_output = add_elementwise_layer( | ||
network, paddle_op, inputs, trt.ElementWiseOperation.LESS | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里分支处理并不严谨,建议改成这样
if paddle_op.name() == "pd_op.greater_than": | |
layer_output = add_elementwise_layer( | |
network, paddle_op, inputs, trt.ElementWiseOperation.GREATER | |
) | |
else: | |
layer_output = add_elementwise_layer( | |
network, paddle_op, inputs, trt.ElementWiseOperation.LESS | |
) | |
if paddle_op.name() == "pd_op.greater_than": | |
layer_output = add_elementwise_layer( | |
network, paddle_op, inputs, trt.ElementWiseOperation.GREATER | |
) | |
elif paddle_op.name() == "pd_op.less_than":: | |
layer_output = add_elementwise_layer( | |
network, paddle_op, inputs, trt.ElementWiseOperation.LESS | |
) | |
else: | |
throw xxx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
return false; | ||
} | ||
#if IS_TRT_VERSION_LT(8400) | ||
VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; | |
VLOG(3) << "pd_op.less_than is not supported when TensorRT < 8.4"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
auto x_dtype = pir::GetDataTypeFromValue(x); | ||
auto y_dtype = pir::GetDataTypeFromValue(y); | ||
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) { | ||
VLOG(3) << "Greater_than op do not support bool datatype"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VLOG(3) << "Greater_than op do not support bool datatype"; | |
VLOG(3) << "pd_op.greater_than op do not support bool datatype"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
62f57eb
PR Category
Inference
PR Types
New features
Description
card-71500
添加greater_than op 和 less_than op的marker、converter和单测