[Typing][A-44] Add type annotations for paddle/optimizer/lbfgs.py
#65308
Changes from 4 commits
@@ -12,14 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 from collections import defaultdict
 from functools import reduce
+from typing import TYPE_CHECKING, Any, Sequence

 import paddle

 from ..base import framework
 from .optimizer import Optimizer

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.nn.clip import GradientClipBase
+    from paddle.regularizer import WeightDecayRegularizer
+
+    from .optimizer import _ParameterConfig
+
 __all__ = []
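For readers unfamiliar with the pattern: the `if TYPE_CHECKING:` block keeps imports that are only needed for annotations out of the runtime import graph, and `from __future__ import annotations` stops the annotations themselves from being evaluated at runtime. A minimal sketch of the same pattern, with illustrative names that are not part of this PR:

    # Sketch of the annotation-only import pattern (names are illustrative).
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only for static type checkers; never executed at runtime,
        # which avoids import cost and circular-import problems.
        from paddle import Tensor


    def scale(x: Tensor, factor: float) -> Tensor:
        # With `from __future__ import annotations`, the annotation is a string
        # at runtime, so the guarded import above is sufficient.
        return x * factor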
@@ -333,28 +343,28 @@ class LBFGS(Optimizer):
         learning_rate (float, optional): learning rate .The default value is 1.
         max_iter (int, optional): maximal number of iterations per optimization step.
             The default value is 20.
-        max_eval (int, optional): maximal number of function evaluations per optimization
+        max_eval (int|None, optional): maximal number of function evaluations per optimization
             step. The default value is max_iter * 1.25.
         tolerance_grad (float, optional): termination tolerance on first order optimality
             The default value is 1e-5.
         tolerance_change (float, optional): termination tolerance on function
             value/parameter changes. The default value is 1e-9.
         history_size (int, optional): update history size. The default value is 100.
-        line_search_fn (string, optional): either 'strong_wolfe' or None. The default value is strong_wolfe.
-        parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
+        line_search_fn (string|None, optional): either 'strong_wolfe' or None. The default value is strong_wolfe.
+        parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
             This parameter is required in dygraph mode. The default value is None.
-        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+        weight_decay (float|WeightDecayRegularizer|None, optional): The strategy of regularization. \
             It canbe a float value as coeff of L2 regularization or \
             :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
             If a parameter has set regularizer using :ref:`api_paddle_ParamAttr` already, \
             the regularization setting here in optimizer will be ignored for this parameter. \
             Otherwise, the regularization setting here in optimizer will take effect. \
             Default None, meaning there is no regularization.
-        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of \
+        grad_clip (GradientClipBase|None, optional): Gradient clipping strategy, it's an instance of \
             some derived class of ``GradientClipBase`` . There are three clipping strategies \
             ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , \
             :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
-        name (str, optional): Normally there is no need for user to set this property.
+        name (str|None, optional): Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`.
             The default value is None.
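Since these docstring entries now spell out `|None`, here is a rough sketch of how the optional regularization and clipping arguments could be passed. The layer and coefficients are made up for illustration and are not taken from the PR:

    import paddle

    # Hypothetical layer used only so there are parameters to optimize.
    linear = paddle.nn.Linear(10, 10)

    opt = paddle.optimizer.LBFGS(
        learning_rate=1.0,
        parameters=linear.parameters(),
        # Both arguments below are optional and default to None.
        weight_decay=paddle.regularizer.L2Decay(1e-4),
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
    )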
@@ -405,18 +415,18 @@ class LBFGS(Optimizer):

     def __init__(
         self,
-        learning_rate=1.0,
-        max_iter=20,
-        max_eval=None,
-        tolerance_grad=1e-7,
-        tolerance_change=1e-9,
-        history_size=100,
-        line_search_fn=None,
-        parameters=None,
-        weight_decay=None,
-        grad_clip=None,
-        name=None,
-    ):
+        learning_rate: float = 1.0,
+        max_iter: int = 20,
+        max_eval: int | None = None,
+        tolerance_grad: float = 1e-7,
+        tolerance_change: float = 1e-9,
+        history_size: int = 100,
+        line_search_fn: str | None = None,
+        parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+        weight_decay: float | WeightDecayRegularizer | None = None,
+        grad_clip: GradientClipBase | None = None,
+        name: str | None = None,
+    ) -> None:
         if max_eval is None:
             max_eval = max_iter * 5 // 4
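To make the newly typed signature concrete, a rough usage sketch of the closure-based API follows. The toy model and data are hypothetical; the authoritative example is the one in the class docstring:

    import paddle

    net = paddle.nn.Linear(4, 1)
    x = paddle.randn([8, 4])
    y = paddle.randn([8, 1])

    opt = paddle.optimizer.LBFGS(
        learning_rate=1.0,
        max_iter=20,
        parameters=net.parameters(),
    )

    def closure():
        # LBFGS may re-evaluate the objective several times per step, so the
        # forward/backward pass lives in a closure handed to step().
        opt.clear_grad()
        loss = paddle.nn.functional.mse_loss(net(x), y)
        loss.backward()
        return loss

    loss = opt.step(closure)

This closure argument is also why the annotation discussion below focuses on what step() returns.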
@@ -452,7 +462,7 @@ def __init__(

         self._numel_cache = None

-    def state_dict(self):
+    def state_dict(self) -> dict[str, Any]:

[Review comment] Use a TypedDict here.

         r"""Returns the state of the optimizer as a :class:`dict`.

         Return:
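On the TypedDict suggestion above, one possible shape for such an annotation might be the following. The class name and helper are hypothetical; the PR as shown here uses `dict[str, Any]`:

    from typing import Any, TypedDict


    class _LbfgsStateDict(TypedDict):
        # Hypothetical name: a TypedDict would pin down the single 'state' key
        # instead of the broader dict[str, Any] return type.
        state: dict[str, Any]


    def pack_state(packed_state: dict[str, Any]) -> _LbfgsStateDict:
        # Illustrative helper mirroring the `return {'state': packed_state}`
        # line shown in the next hunk.
        return {'state': packed_state}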
@@ -505,7 +515,7 @@ def state_dict(self):

         return {'state': packed_state}

-    def _numel(self):
+    def _numel(self) -> int:
         # compute the number of all parameters
         if self._numel_cache is None:
             self._numel_cache = reduce(
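The truncated `reduce(` call above caches the total number of elements across all parameters, which is why `int` is the natural return type. Roughly, the computation has this shape (the parameter list below is made up, not the exact Paddle code):

    from functools import reduce

    import paddle

    # Hypothetical parameter list; the real one belongs to the optimizer.
    params = [paddle.randn([3, 4]), paddle.randn([5])]

    # Total element count across all parameters, analogous to what _numel()
    # caches in self._numel_cache.
    total = reduce(lambda acc, p: acc + int(p.numel()), params, 0)
    print(total)  # 12 + 5 = 17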
@@ -553,7 +563,7 @@ def _directional_evaluate(self, closure, x, alpha, d):
         return loss, flat_grad

     @framework.non_static_only
-    def step(self, closure):
+    def step(self, closure) -> Any:

[Review comment] Shouldn't this be Tensor?

         """Performs a single optimization step.

         Args:
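Following the review comment above: if step() simply returns the loss produced by the closure, the annotation could be narrowed along these lines (a sketch of the suggested signature only, not the real implementation):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Callable

    if TYPE_CHECKING:
        from paddle import Tensor


    class _SignatureSketch:
        # Shows only the narrowed annotation suggested in the review;
        # the body is a stand-in, not the real LBFGS.step implementation.
        def step(self, closure: Callable[[], Tensor]) -> Tensor:
            return closure()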
@@ -778,7 +788,7 @@ def obj_func(x, alpha, d):

     def minimize(
         self, loss, startup_program=None, parameters=None, no_grad_set=None
-    ):
+    ) -> Any:

[Review comment] Use NoReturn.
[Reply] Got it.
[Review comment] What I meant here isn't …

         """Empty method. LBFGS optimizer does not use this way to minimize ``loss``. Please refer 'Examples' of LBFGS() above for usage."""
         raise NotImplementedError(
             "LBFGS optimizer does not use this way to minimize loss. Please refer 'Examples' of LBFGS() for usage."
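On the NoReturn suggestion: because minimize() unconditionally raises NotImplementedError, its return annotation can state that it never returns normally. A standalone sketch of the idea, not the PR's final wording:

    from typing import NoReturn


    def minimize_stub(*args, **kwargs) -> NoReturn:
        # Mirrors the behaviour above: the method never returns normally,
        # so NoReturn is a tighter annotation than Any.
        raise NotImplementedError(
            "LBFGS optimizer does not use this way to minimize loss."
        )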
[Review comment] A few more APIs, please don't miss them:
state_dict
step
minimize
This round we are mainly focusing on public APIs, in particular functions and methods whose names do not start with `_`. Also note that the per-file API counts in the ISSUE may be inaccurate; for example, a class may be counted only once even though all of its methods need to be updated as well. You can check the official docs for the full list of interfaces; better to cover too many than to miss any ~ 🤟🤟🤟

[Reply] Got it.