[API Compatibility] Add paddle.compat.nn.Linear #76169
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
b7540cb to 1dafe95
Codecov Report ❌ Patch coverage is

```
@@           Coverage Diff            @@
##           develop   #76169   +/-   ##
==========================================
  Coverage         ?   97.29%
==========================================
  Files            ?        2
  Lines            ?       37
  Branches         ?        0
==========================================
  Hits             ?       36
  Misses           ?        1
  Partials         ?        0
==========================================
```
XiaoguangHu01 left a comment:
LGTM
```python
# KaimingUniform initializer should be more flexible: user should be able to specify place
expected_place = paddle.base.framework._current_expected_place()
original_place = self.weight.place
nn.init.kaiming_uniform_(self.weight, a=sqrt(5))
```
Is this because `init.kaiming_uniform_` changes the place? That API was also added only recently; if it has a bug, feel free to fix it directly.
Yes. `init.kaiming_uniform_` moves the input Tensor to `_current_expected_place()` regardless of where it originally lives; if the interface accepted a `place` argument, this conversion would be unnecessary. That said, computing on the default place (e.g. GPU tensors on a GPU device, CPU tensors on a CPU device) is usually fine.
```python
if place_mismatch and in_dynamic_mode():
    self.weight = self.weight.to(original_place)
if self.bias is not None:
    # nn.init._calculate_fan_in_and_fan_out(self.weight) for 2D array
```
Is the implementation of `nn.init._calculate_fan_in_and_fan_out` incorrect? That API was also added only recently; if it has a bug, feel free to fix it directly.
`nn.init._calculate_fan_in_and_fan_out` is a PyTorch interface; Paddle does not have it. An equivalent formulation is used here, and `Linear` does not actually need that API (the 2D case is straightforward).
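For a 2D weight of shape `[out_features, in_features]` the equivalent computation really is trivial. A minimal sketch in plain Python (the helper name is illustrative, not a Paddle API) of the fan calculation and the resulting bias bound used by the torch scheme:

```python
import math

# Hypothetical helper mirroring what the comment describes; not a Paddle API.
def fan_in_and_fan_out_2d(shape):
    """For a 2D weight of shape [out_features, in_features],
    fan_in is the input dimension and fan_out the output dimension."""
    out_features, in_features = shape
    return in_features, out_features

fan_in, fan_out = fan_in_and_fan_out_2d((4, 3))  # fan_in=3, fan_out=4
# torch.nn.Linear draws the bias from U(-bound, bound) with bound = 1/sqrt(fan_in)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
```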
```python
)
self.in_features = in_features
self.out_features = out_features
self.weight = self.create_parameter(
```
Another option is to set a `weight_attr`/`bias_attr` in `create_parameter` with initializers that align with torch; pick whichever is easier to write.
For initialization alone, passing an `Initializer` via `weight_attr`/`bias_attr` is simpler and avoids the manual conversion in dynamic mode. However, `reset_parameters` is a member function of `torch.nn.Linear`, and many models call it to reset parameters in place, so I reused `reset_parameters` instead of providing an `Initializer`. The `kaiming_uniform_` issue inside `reset_initializer` still exists. I will follow up with a PR to simplify the initialization; this PR just provides the aligned functionality.
Using `reset_parameters` is fine too; let's look at these two points later:
- `torch.nn.init` does not change the weight's place; our implementation looks like it has redundant operations, and it would be good to eliminate these extra copies.
- `_calculate_fan_in_and_fan_out` is also planned for this cycle, as an alias of `_compute_fans`; it would be best to reuse it here.
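For reference, the torch-aligned scheme that `reset_parameters` reproduces can be sketched in plain NumPy (a sketch under the assumption of a 2D weight; the function name is illustrative, not Paddle's or PyTorch's API):

```python
import math
import numpy as np

def reset_parameters_reference(out_features, in_features, rng=None):
    """Reference for torch.nn.Linear's default scheme: kaiming-uniform
    weight with a=sqrt(5), uniform bias with bound 1/sqrt(fan_in)."""
    rng = np.random.default_rng() if rng is None else rng
    a = math.sqrt(5)
    fan_in = in_features                      # 2D weight [out_features, in_features]
    gain = math.sqrt(2.0 / (1 + a * a))       # leaky_relu gain with negative_slope=a
    w_bound = gain * math.sqrt(3.0 / fan_in)  # equals 1/sqrt(fan_in) when a=sqrt(5)
    weight = rng.uniform(-w_bound, w_bound, size=(out_features, in_features))
    b_bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    bias = rng.uniform(-b_bound, b_bound, size=(out_features,))
    return weight, bias
```

Note that with `a = sqrt(5)` the weight bound simplifies to `1/sqrt(fan_in)`, the same bound the bias uses.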
zhwesky2010 left a comment:
Let's merge this one first and revisit in the next PR.
PR Category: User Experience
PR Types: New features
Description
Added the PyTorch-aligned `paddle.compat.nn.Linear`, which calls `compat.nn.functional.linear` from #76144. In addition to aligning parameter usage and mathematical semantics, the initialization is also aligned (kaiming uniform initialization for `weight`, uniform initialization for `bias`). All PaConvert tests pass.
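The torch-aligned semantics being matched here can be stated as a small NumPy reference (a sketch, not the Paddle implementation): the weight is stored as `[out_features, in_features]`, so the op computes `y = x @ W.T + b`.

```python
import numpy as np

# NumPy reference for the torch-aligned linear semantics:
# weight shape [out_features, in_features], output y = x @ W.T + b.
def linear_reference(x, weight, bias=None):
    y = x @ weight.T
    if bias is not None:
        y = y + bias
    return y

x = np.ones((2, 3))
w = np.arange(12, dtype=float).reshape(4, 3)  # [out_features=4, in_features=3]
b = np.zeros(4)
y = linear_reference(x, w, b)  # shape (2, 4); each row holds the row sums of w
```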
TODO:
Pcard-89620