Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【易用性提升No.7】为transformerDecoderLayer新增layer_norm_eps参数 #60084

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/paddle/nn/layer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ class TransformerDecoderLayer(Layer):
corresponding layer would not have trainable bias parameter. See
usage for details in :code:`ParamAttr` . Default: None,which means
the default bias parameter property is used.
layer_norm_eps: the eps value in layer normalization components. Default=1e-5.

Examples:

Expand Down Expand Up @@ -843,6 +844,7 @@ def __init__(
normalize_before=False,
weight_attr=None,
bias_attr=None,
layer_norm_eps=1e-5,
):
self._config = locals()
self._config.pop("self")
Expand Down Expand Up @@ -889,9 +891,9 @@ def __init__(
self.linear2 = Linear(
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]
)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.norm1 = LayerNorm(d_model, layer_norm_eps)
self.norm2 = LayerNorm(d_model, layer_norm_eps)
self.norm3 = LayerNorm(d_model, layer_norm_eps)
self.dropout1 = Dropout(dropout, mode="upscale_in_train")
self.dropout2 = Dropout(dropout, mode="upscale_in_train")
self.dropout3 = Dropout(dropout, mode="upscale_in_train")
Expand Down