Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Micalling committed Oct 10, 2024
1 parent 12e0d38 commit 6d98f48
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions python/paddle/nn/layer/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class ParameterDict(Layer):
... super().__init__()
... # create ParameterDict with iterable Parameters
... self.params = paddle.nn.ParameterDict(
... {'t' + str(i): paddle.create_parameter(shape=[2, 2], dtype='float32') for i in range(num_stacked_param)})
... {f"t{i}": paddle.create_parameter(shape=[2, 2], dtype='float32') for i in range(num_stacked_param)})
...
... def forward(self, x):
... for i, (key, p) in enumerate(self.params):
Expand Down Expand Up @@ -371,7 +371,7 @@ def __init__(
self,
parameters: (
ParameterDict
| typing.Mapping[str, Tensor]
| Mapping[str, Tensor]
| Sequence[tuple[str, Tensor]]
| None
) = None,
Expand All @@ -393,13 +393,13 @@ def __len__(self) -> int:

def __iter__(self) -> Iterator[tuple[str, Tensor]]:
with param_guard(self._parameters):
return iter(self._parameters.items())
return iter(self._parameters)

def update(
self,
parameters: (
ParameterDict
| typing.Mapping[str, Tensor]
| Mapping[str, Tensor]
| Sequence[tuple[str, Tensor]]
),
) -> None:
Expand All @@ -420,11 +420,7 @@ def update(
for i, kv in enumerate(parameters):
if len(kv) != 2:
raise ValueError(
"The length of the "
+ str(i)
+ "'s element in parameters is "
+ str(len(kv))
+ ", which must be 2."
f"The length of the {i}'s element in parameters is {len(kv)}, which must be 2."
)
self.add_parameter(kv[0], kv[1])

Expand Down

0 comments on commit 6d98f48

Please sign in to comment.