[Typing][B-89] Add type annotations for `python/paddle/nn/utils/weight_norm_hook.py` (#65812)
megemini authored Jul 9, 2024
1 parent 5b1b406 commit 7dbc881
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions python/paddle/nn/utils/weight_norm_hook.py
@@ -11,17 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import paddle
 from paddle import _C_ops

 from ...base.data_feeder import check_variable_and_dtype
 from ...base.layer_helper import LayerHelper
 from ...framework import in_dynamic_mode

+if TYPE_CHECKING:
+    from typing_extensions import Never
+
+    from paddle import Tensor
+    from paddle.nn import Layer
+
 __all__ = []


-def l2_norm(x, axis, epsilon=1e-12, name=None):
+def l2_norm(
+    x: Tensor, axis: int, epsilon: float = 1e-12, name: str | None = None
+) -> Tensor:
     if len(x.shape) == 1:
         axis = 0
@@ -46,7 +59,7 @@ def l2_norm(x, axis, epsilon=1e-12, name=None):
     return paddle.squeeze(norm, axis=[axis])


-def norm_except_dim(p, dim):
+def norm_except_dim(p: Tensor, dim: int) -> Tensor:
     shape = p.shape
     ndims = len(shape)
     if dim == -1:
@@ -65,7 +78,7 @@ def norm_except_dim(p, dim):
         return norm_except_dim(p_transposed, 0)


-def _weight_norm(v, g, dim):
+def _weight_norm(v: Tensor, g: Tensor, dim: int) -> Tensor:
     shape = v.shape
     ndims = len(shape)

@@ -96,19 +109,22 @@ def _weight_norm(v, g, dim):


 class WeightNorm:
-    def __init__(self, name, dim):
+    name: str
+    dim: int
+
+    def __init__(self, name: str, dim: int) -> None:
         if dim is None:
             dim = -1
         self.name = name
         self.dim = dim

-    def compute_weight(self, layer):
+    def compute_weight(self, layer: Layer) -> Tensor:
         g = getattr(layer, self.name + '_g')
         v = getattr(layer, self.name + '_v')
         return _weight_norm(v, g, self.dim)

     @staticmethod
-    def apply(layer, name, dim):
+    def apply(layer: Layer, name: str, dim: int) -> WeightNorm:
         for k, hook in layer._forward_pre_hooks.items():
             if isinstance(hook, WeightNorm) and hook.name == name:
                 raise RuntimeError(
@@ -145,7 +161,7 @@ def apply(layer, name, dim):
         layer.register_forward_pre_hook(fn)
         return fn

-    def remove(self, layer):
+    def remove(self, layer: Layer) -> None:
         w_var = self.compute_weight(layer)
         delattr(layer, self.name)
         del layer._parameters[self.name + '_g']
@@ -155,11 +171,11 @@ def remove(self, layer):
         with paddle.no_grad():
             paddle.assign(w_var, w)

-    def __call__(self, layer, inputs):
+    def __call__(self, layer: Layer, inputs: Never) -> None:
         setattr(layer, self.name, self.compute_weight(layer))


-def weight_norm(layer, name='weight', dim=0):
+def weight_norm(layer: Layer, name: str = 'weight', dim: int = 0) -> Layer:
     r"""
     Applies weight normalization to a parameter according to the
     following formula:
@@ -205,7 +221,7 @@ def weight_norm(layer, name='weight', dim=0):
     return layer


-def remove_weight_norm(layer, name='weight'):
+def remove_weight_norm(layer: Layer, name: str = 'weight') -> Layer:
     """
     remove weight normalization from layer.
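
For orientation, here is a minimal usage sketch of the two annotated public APIs, `weight_norm` and `remove_weight_norm`. It is not part of the commit; it assumes a working Paddle installation and relies on the `name + '_g'` / `name + '_v'` parameter naming visible in `compute_weight` above.

# Minimal usage sketch (illustration only, not part of this commit).
# Assumes a working Paddle installation.
import paddle
from paddle.nn.utils import remove_weight_norm, weight_norm

linear = paddle.nn.Linear(3, 5)

# weight_norm(layer: Layer, name: str = 'weight', dim: int = 0) -> Layer
# splits `weight` into a magnitude `weight_g` and a direction `weight_v`.
linear = weight_norm(linear, name="weight", dim=0)
print(linear.weight_g.shape, linear.weight_v.shape)

# The WeightNorm forward pre-hook recomputes `weight` before each call.
out = linear(paddle.randn([2, 3]))

# remove_weight_norm(layer: Layer, name: str = 'weight') -> Layer
# restores a plain `weight` parameter and removes the hook.
linear = remove_weight_norm(linear, name="weight")
print(linear.weight.shape)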
