Feat: AddBiasResidualLayerNorm #9906
Conversation
zobinHuang
commented
Feb 26, 2023
Macros don't offer many advantages here; a function or a lambda is usually the better choice.
if (nb_skip >= 1) { new_op = new_op.Input("skip"); }
new_op = new_op.Output("y").Output("mean").Output("inv_variance");
std::shared_ptr<OpExpr> op_pointer = CHECK_JUST(new_op.Build());
Suggested change:
- std::shared_ptr<OpExpr> op_pointer = CHECK_JUST(new_op.Build());
+ std::shared_ptr<OpExpr> op_expr = CHECK_JUST(new_op.Build());
Unless the type needs special emphasis or disambiguation, we generally don't make the type part of a name, since that part of the name is redundant.
if (has_gamma) { new_op = new_op.Input("gamma"); }
if (has_beta) { new_op = new_op.Input("beta"); }
if (has_bias) { new_op = new_op.Input("bias"); }
if (nb_skip >= 1) { new_op = new_op.Input("skip"); }
Change `for (nb_skip = 0; nb_skip <= 1; nb_skip++)` to iterate over {false, true}; otherwise this still needs extra special handling to be correct.
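The suggested `{false, true}` loop can be sketched as follows (a hypothetical helper, not the OneFlow code itself, with simplified types):

```cpp
#include <vector>

// Enumerate the skip/no-skip variants over explicit booleans rather than an
// int loop counter, so no int-to-bool conversion or special handling is needed.
std::vector<bool> EnumerateSkipVariants() {
  std::vector<bool> variants;
  for (bool has_skip : {false, true}) {  // clearer than for (nb_skip = 0; nb_skip <= 1; ...)
    variants.push_back(has_skip);
  }
  return variants;
}
```

Ranging directly over a brace-enclosed `{false, true}` initializer list keeps the loop variable a `bool` throughout, matching the `for (bool has_beta : bool_list)` style already used elsewhere in this PR.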
for (bool has_beta : bool_list) {
  /* has_bias */
  for (bool has_bias : bool_list) {
    one::OpBuilder new_op = one::OpBuilder("skip_layer_norm").Input("x");
one::OpBuilder new_op = one::OpBuilder("skip_layer_norm").Input("x");
new_op should be named op_builder or builder.
const double& epsilon, const double& alpha) const {
// check shape of x
const auto& x_shape = *(x->shape());
CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1)
Consider changing this check to GE(2).
CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1)
    << "number of axes of \'x\' should have be greater than 1, yet get " << x_shape.NumAxes();
#define GAMMA_BETA_BIAS_SHAPE_CHECK(tensor) \
A lambda feels more suitable here than a macro.
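The lambda alternative to a GAMMA_BETA_BIAS_SHAPE_CHECK-style macro could look roughly like this. This is a simplified stand-in (plain `std::vector<long>` shapes and a return value instead of OneFlow's Shape type and CHECK_*_OR_RETURN macros), not the actual PR code:

```cpp
#include <string>
#include <vector>

// One local lambda, reused for each optional parameter, replaces the macro:
// it captures x_shape once and needs no token pasting or macro hygiene care.
// Returns the names of parameters whose shape fails the check.
std::vector<std::string> BadAffineParams(const std::vector<long>& x_shape,
                                         const std::vector<long>& gamma_shape,
                                         const std::vector<long>& beta_shape) {
  std::vector<std::string> bad;
  auto check = [&](const std::vector<long>& s, const std::string& name) {
    // gamma/beta/bias must be 1-D and match the last axis of x.
    if (s.size() != 1 || s[0] != x_shape.back()) { bad.push_back(name); }
  };
  check(gamma_shape, "gamma");
  check(beta_shape, "beta");
  return bad;
}
```

Unlike the macro, the lambda is type-checked as ordinary code, is debuggable, and stays scoped to the function that defines it.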
// set output shape of mean and varience
DimVector mean_dim_vec;
mean_dim_vec.push_back(x_shape.Count(0, x_shape.NumAxes() - 1));
Shape mean_shape(mean_dim_vec);  // borrow from input shape
Suggested change:
- Shape mean_shape(mean_dim_vec); // borrow from input shape
+ Shape mean_shape(mean_dim_vec);

If you keep a comment, make sure it is accurate.
<< "data type of \'gamma\' is not consitant with \'x\'"; | ||
} | ||
|
||
// check data type of pre_bias |
Suggested change:
- // check data type of pre_bias
+ // check data type of bias
<< "data type of \'beta\' is not consitant with \'x\'"; | ||
} | ||
|
||
// check data types of pre_residual_1 and pre_residual_2 |
Suggested change:
- // check data types of pre_residual_1 and pre_residual_2
+ // check data type of skip
template<typename SRC, typename DST>
struct SkipLoad {
  using LoadType = DST;
  SkipLoad(const SRC* src, const SRC* bias, const SRC* skip, double alpha, int64_t row_size)
Use float for alpha here rather than double, because we don't want any double-precision computation to appear.
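A simplified sketch of what the float-alpha version might look like (hypothetical, host-side, non-templated stand-in for the SkipLoad functor above, not the actual PR code):

```cpp
#include <cstdint>

// Fused load: adds bias and an alpha-scaled skip connection to each element.
// alpha is held as float, not double, so the multiply stays in single precision.
struct SkipLoadF {
  const float* src;
  const float* bias;
  const float* skip;
  float alpha;  // float, not double: avoids silently promoting the math to double
  int64_t row_size;

  float Load(int64_t row, int64_t col) const {
    const int64_t offset = row * row_size + col;
    float x = src[offset] + bias[col];
    if (skip != nullptr) { x += alpha * skip[offset]; }
    return x;
  }
};

// Tiny demo: one row of two elements, alpha = 0.5.
inline float DemoLoad() {
  static const float src[2] = {1.0f, 2.0f};
  static const float bias[2] = {0.5f, 0.5f};
  static const float skip[2] = {2.0f, 2.0f};
  SkipLoadF load{src, bias, skip, 0.5f, 2};
  return load.Load(0, 0);  // 1.0 + 0.5 + 0.5 * 2.0
}
```

If alpha were a double, `alpha * skip[offset]` would promote to double arithmetic in every element load, which is exactly the cost the review comment wants to avoid on device.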
// obtain epsilon and check its value
const double epsilon = ctx->Attr<double>("epsilon");
const double alpha = ctx->Attr<double>("alpha");
CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);
This check is no longer necessary.
// check shape of x
const auto& x_shape = *(x->shape());
CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2)
    << "number of axes of \'x\' should have be greater than 1, yet get " << x_shape.NumAxes();
Update the error message here to stay in sync with the check.
<< "number of axes of \'gamma\' should have be equal to 1, yet get " | ||
<< gamma_shape.NumAxes(); | ||
CHECK_EQ_OR_RETURN(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) | ||
<< "dimension 1 of \'gamma\'(" << gamma_shape.At(0) |
The "dimension 1" here may be ambiguous depending on whether counting starts from 0 or 1; consider rewording it.
}
bool has_gamma = false, has_beta = false, has_bias = false;
if (gamma) { has_gamma = true; }
Could has_skip and has_gamma be written in a consistent style?
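One consistent style would be to derive every presence flag directly from its optional input, instead of mixing `bool x = false; if (p) x = true;` with other patterns. A simplified sketch (hypothetical `MakeFlags` helper with `std::shared_ptr` standing in for OneFlow's optional tensor type):

```cpp
#include <memory>

// Presence flags for the optional layer-norm inputs.
struct Flags {
  bool has_gamma;
  bool has_beta;
  bool has_bias;
};

// Each flag is computed the same way: a direct bool conversion of the
// (possibly null) optional input, with no if/assignment boilerplate.
Flags MakeFlags(const std::shared_ptr<int>& gamma,
                const std::shared_ptr<int>& beta,
                const std::shared_ptr<int>& bias) {
  return Flags{static_cast<bool>(gamma), static_cast<bool>(beta),
               static_cast<bool>(bias)};
}
```

Initializing all flags in one place with identical expressions makes it hard for one flag (such as has_skip) to drift into a different idiom later.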
CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1)
    << "number of axes of \'x\' should have be greater than 1, yet get " << x_shape.NumAxes();
#define GAMMA_BETA_BIAS_SHAPE_CHECK(tensor) \
A macro is still being used here.
CHECK_GT(x_shape.NumAxes(), 1)
    << "number of axes of \'x\' should have be greater than 1, yet get " << x_shape.NumAxes();
#define GET_GAMMA_BETA_BIAS_AND_SHAPE_CHECK(tensor) \
Can these places avoid using macros?
Speed stats:
CI failed when running job: cuda-module. PR label automerge has been removed.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9906/