[Typing][A-44] Add type annotations for paddle/optimizer/lbfgs.py
#65308
Merged
Changes from 7 commits (of 10 total):

- 58f53f8 fix (enkilee)
- c400d3c Merge branch 'PaddlePaddle:develop' into typing-a44-lbfgs (enkilee)
- 18a1959 fix (enkilee)
- 59db89f Merge branch 'typing-a44-lbfgs' of https://github.com/enkilee/Paddle … (enkilee)
- 96d160c fix (enkilee)
- 61919b2 fix (enkilee)
- 139aa31 fix (enkilee)
- 7ca8924 fix (enkilee)
- 40a1277 fix (enkilee)
- f84f2bc add virtual method `__iter__` and fix sample code (SigureMo)
paddle/optimizer/lbfgs.py

```diff
@@ -12,17 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from collections import defaultdict
 from functools import reduce
+from typing import TYPE_CHECKING, NoReturn, Sequence, TypedDict
 
 import paddle
 
 from ..base import framework
 from .optimizer import Optimizer
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.nn.clip import GradientClipBase
+    from paddle.regularizer import WeightDecayRegularizer
+
+    from .optimizer import _ParameterConfig
+
 __all__ = []
 
 
+class _LbfgsStateDict(TypedDict):
+    state: str
+    packed_state: _LbfgsStateDict
+
+
 def dot(x, y):
     r"""
     NOTE: This is a temporary workaround for unstable result computed by `paddle.dot`,
```

Review comment on the `_LbfgsStateDict` lines: Haha, I can't quite follow this. The code returns a dict, and here the dict is given the key `state` plus the keys and values of the dict that `state` maps to?
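For readers puzzled by the same thing: a `TypedDict` only declares the key names and value types of a plain dict. A minimal standalone sketch of what an annotated `state_dict` returning `{'state': packed_state}` could look like; the inner keys `func_evals` and `n_iter` are illustrative assumptions, not necessarily this PR's final layout:

```python
from typing import TypedDict


class _PackedState(TypedDict):
    # Hypothetical inner keys; the real packed state holds the
    # LBFGS bookkeeping values keyed by name.
    func_evals: int
    n_iter: int


class _StateDict(TypedDict):
    # The outer dict has exactly one key, 'state', whose value
    # is the packed inner dict.
    state: _PackedState


def state_dict() -> _StateDict:
    packed_state: _PackedState = {"func_evals": 0, "n_iter": 0}
    return {"state": packed_state}


print(state_dict())  # {'state': {'func_evals': 0, 'n_iter': 0}}
```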
```diff
@@ -333,28 +348,28 @@ class LBFGS(Optimizer):
         learning_rate (float, optional): learning rate .The default value is 1.
         max_iter (int, optional): maximal number of iterations per optimization step.
             The default value is 20.
-        max_eval (int, optional): maximal number of function evaluations per optimization
+        max_eval (int|None, optional): maximal number of function evaluations per optimization
             step. The default value is max_iter * 1.25.
         tolerance_grad (float, optional): termination tolerance on first order optimality
             The default value is 1e-5.
         tolerance_change (float, optional): termination tolerance on function
             value/parameter changes. The default value is 1e-9.
         history_size (int, optional): update history size. The default value is 100.
-        line_search_fn (string, optional): either 'strong_wolfe' or None. The default value is strong_wolfe.
-        parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
+        line_search_fn (string|None, optional): either 'strong_wolfe' or None. The default value is strong_wolfe.
+        parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
             This parameter is required in dygraph mode. The default value is None.
-        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+        weight_decay (float|WeightDecayRegularizer|None, optional): The strategy of regularization. \
             It canbe a float value as coeff of L2 regularization or \
             :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
             If a parameter has set regularizer using :ref:`api_paddle_ParamAttr` already, \
             the regularization setting here in optimizer will be ignored for this parameter. \
             Otherwise, the regularization setting here in optimizer will take effect. \
             Default None, meaning there is no regularization.
-        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of \
+        grad_clip (GradientClipBase|None, optional): Gradient clipping strategy, it's an instance of \
             some derived class of ``GradientClipBase`` . There are three clipping strategies \
             ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , \
             :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
-        name (str, optional): Normally there is no need for user to set this property.
+        name (str|None, optional): Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`.
             The default value is None.
 
```
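As a reading aid for the annotated parameters above, here is a minimal construction sketch in the spirit of the class's own example section; the linear model and the concrete values are illustrative, not taken from this diff:

```python
import paddle

# Illustrative module whose parameters the optimizer will update.
linear = paddle.nn.Linear(10, 1)

opt = paddle.optimizer.LBFGS(
    learning_rate=1.0,               # float
    max_iter=20,                     # int
    max_eval=None,                   # int | None, defaults to max_iter * 5 // 4
    line_search_fn="strong_wolfe",   # str | None
    parameters=linear.parameters(),  # Sequence[Tensor] | Sequence[_ParameterConfig] | None
    weight_decay=0.01,               # float | WeightDecayRegularizer | None
    grad_clip=None,                  # GradientClipBase | None
    name=None,                       # str | None
)
```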
```diff
@@ -405,18 +420,18 @@ class LBFGS(Optimizer):
 
     def __init__(
         self,
-        learning_rate=1.0,
-        max_iter=20,
-        max_eval=None,
-        tolerance_grad=1e-7,
-        tolerance_change=1e-9,
-        history_size=100,
-        line_search_fn=None,
-        parameters=None,
-        weight_decay=None,
-        grad_clip=None,
-        name=None,
-    ):
+        learning_rate: float = 1.0,
+        max_iter: int = 20,
+        max_eval: int | None = None,
+        tolerance_grad: float = 1e-7,
+        tolerance_change: float = 1e-9,
+        history_size: int = 100,
+        line_search_fn: str | None = None,
+        parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+        weight_decay: float | WeightDecayRegularizer | None = None,
+        grad_clip: GradientClipBase | None = None,
+        name: str | None = None,
+    ) -> None:
         if max_eval is None:
             max_eval = max_iter * 5 // 4
 
```
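The `Sequence[Tensor] | Sequence[_ParameterConfig]` union mirrors the two forms Paddle optimizers generally accept: a flat parameter list, or a list of parameter-group dicts. A small sketch of both forms; the group keys shown are the usual Paddle optimizer group fields, assumed here rather than taken from this diff:

```python
import paddle

linear = paddle.nn.Linear(10, 1)

# Form 1: Sequence[Tensor], a plain parameter list.
opt_flat = paddle.optimizer.LBFGS(parameters=linear.parameters())

# Form 2: Sequence[_ParameterConfig], parameter-group dicts that can
# override per-group settings such as weight_decay.
opt_groups = paddle.optimizer.LBFGS(
    parameters=[
        {"params": linear.parameters(), "weight_decay": 0.001},
    ]
)
```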
```diff
@@ -452,7 +467,7 @@ def __init__(
 
         self._numel_cache = None
 
-    def state_dict(self):
+    def state_dict(self) -> _LbfgsStateDict:
         r"""Returns the state of the optimizer as a :class:`dict`.
 
         Return:
```
```diff
@@ -505,7 +520,7 @@ def state_dict(self):
 
         return {'state': packed_state}
 
-    def _numel(self):
+    def _numel(self) -> int:
         # compute the number of all parameters
         if self._numel_cache is None:
             self._numel_cache = reduce(
```
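For context on the `-> int` annotation: `_numel` folds the element counts of all parameters into a single cached total. A standalone sketch of that `reduce` pattern; the parameter shapes are made up, and this is not the method's exact body:

```python
from functools import reduce

import paddle

# Hypothetical parameter tensors standing in for self._params.
params = [paddle.zeros([10, 1]), paddle.zeros([1])]

# Fold each parameter's element count into a running total,
# as a cached _numel would.
numel = reduce(lambda total, p: total + p.numel(), params, 0)
print(int(numel))  # 11
```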
```diff
@@ -553,7 +568,7 @@ def _directional_evaluate(self, closure, x, alpha, d):
         return loss, flat_grad
 
     @framework.non_static_only
-    def step(self, closure):
+    def step(self, closure) -> Tensor:
         """Performs a single optimization step.
 
         Args:
```
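Since `step` now advertises a `Tensor` return, here is a minimal closure-driven step in the style of the class docstring's example; the data and model are illustrative:

```python
import paddle

inputs = paddle.rand([10, 10], dtype="float32")
labels = paddle.rand([10, 1], dtype="float32")
linear = paddle.nn.Linear(10, 1)
loss_fn = paddle.nn.MSELoss()
opt = paddle.optimizer.LBFGS(parameters=linear.parameters())


def closure():
    # LBFGS may evaluate the closure several times within one step,
    # so it must recompute the loss and gradients from scratch.
    opt.clear_grad()
    loss = loss_fn(linear(inputs), labels)
    loss.backward()
    return loss


loss = opt.step(closure)  # the annotated return type: a loss Tensor
print(float(loss))
```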
```diff
@@ -778,7 +793,7 @@ def obj_func(x, alpha, d):
 
     def minimize(
         self, loss, startup_program=None, parameters=None, no_grad_set=None
-    ):
+    ) -> NoReturn:
         """Empty method. LBFGS optimizer does not use this way to minimize ``loss``. Please refer 'Examples' of LBFGS() above for usage."""
         raise NotImplementedError(
             "LBFGS optimizer does not use this way to minimize loss. Please refer 'Examples' of LBFGS() for usage."
```
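The `NoReturn` annotation is the conventional way to tell type checkers that a method always raises rather than returning. A tiny self-contained sketch of the same pattern; the function name is illustrative:

```python
from typing import NoReturn


def minimize_stub() -> NoReturn:
    # A NoReturn function must never return normally; this one
    # unconditionally raises, matching the annotated contract.
    raise NotImplementedError(
        "LBFGS optimizer does not use this way to minimize loss."
    )
```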
Review comment:

There are a few more APIs here, don't miss them ~

- state_dict
- step
- minimize

This round we are mainly looking at the public APIs, focusing on the functions and methods whose names do not start with `_`. Also, the per-file API counts in the tracking ISSUE may be inaccurate; for example, a class may be counted only once even though the methods inside it all need updating too. You can cross-check the interfaces listed on the official docs site; better to cover too many than to miss any ~ 🤟🤟🤟

Reply: Got it.