Skip to content

Commit

Permalink
[Typing] 修复 shard_optimizer 以及示例中的类型标注错误 (#67529)
Browse files Browse the repository at this point in the history
* [Fix] typing

* [Fix] typing and ex run timeout
  • Loading branch information
megemini authored Aug 21, 2024
1 parent 710797d commit 6e805a2
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 21 deletions.
5 changes: 3 additions & 2 deletions python/paddle/audio/datasets/esc50.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ class ESC50(AudioClassificationDataset):
.. code-block:: python
>>> # doctest: +TIMEOUT(60)
>>> import paddle
>>> mode = 'dev'
>>> esc50_dataset = paddle.audio.datasets.ESC50(mode=mode,
>>> esc50_dataset = paddle.audio.datasets.ESC50(mode=mode, # type: ignore[arg-type]
... feat_type='raw')
>>> for idx in range(5):
... audio, label = esc50_dataset[idx]
Expand All @@ -80,7 +81,7 @@ class ESC50(AudioClassificationDataset):
[220500] 36
[220500] 19
>>> esc50_dataset = paddle.audio.datasets.ESC50(mode=mode,
>>> esc50_dataset = paddle.audio.datasets.ESC50(mode=mode, # type: ignore[arg-type]
... feat_type='mfcc',
... n_mfcc=40)
>>> for idx in range(5):
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/audio/datasets/tess.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,19 @@ class TESS(AudioClassificationDataset):
.. code-block:: python
>>> # doctest: +TIMEOUT(60)
>>> import paddle
>>> mode = 'dev'
>>> tess_dataset = paddle.audio.datasets.TESS(mode=mode,
>>> tess_dataset = paddle.audio.datasets.TESS(mode=mode, # type: ignore[arg-type]
... feat_type='raw')
>>> for idx in range(5):
... audio, label = tess_dataset[idx]
... # do something with audio, label
... print(audio.shape, label)
... # [audio_data_length] , label_id
>>> tess_dataset = paddle.audio.datasets.TESS(mode=mode,
>>> tess_dataset = paddle.audio.datasets.TESS(mode=mode, # type: ignore[arg-type]
... feat_type='mfcc',
... n_mfcc=40)
>>> for idx in range(5):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:

def shard_optimizer(
optimizer: Optimizer,
shard_fn: Callable[[str, Layer, ProcessMesh], None] | None = None,
shard_fn: Callable[[str, Tensor, Tensor], Tensor] | None = None,
) -> _ShardOptimizer:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distributed/communication/stream/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def alltoall_single(
>>> else:
... data = paddle.to_tensor([2, 3])
>>> task = dist.stream.alltoall_single(output, data, sync_op=False)
>>> task.wait() # type: ignore[union-attr]
>>> task.wait()
>>> out = output.numpy()
>>> print(out)
>>> # [0, 2] (2 GPUs, out for rank 0)
Expand All @@ -339,7 +339,7 @@ def alltoall_single(
... out_split_sizes,
... in_split_sizes,
... sync_op=False)
>>> task.wait() # type: ignore[union-attr]
>>> task.wait()
>>> out = output.numpy()
>>> print(out)
>>> # [[0., 0.], [1., 1.]] (2 GPUs, out for rank 0)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/geometric/sampling/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sample_neighbors(
>>> row = paddle.to_tensor(row, dtype="int64")
>>> colptr = paddle.to_tensor(colptr, dtype="int64")
>>> nodes = paddle.to_tensor(nodes, dtype="int64")
>>> out_neighbors, out_count = paddle.geometric.sample_neighbors(row, colptr, nodes, sample_size=sample_size)
>>> out_neighbors, out_count = paddle.geometric.sample_neighbors(row, colptr, nodes, sample_size=sample_size, return_eids=False)
"""

Expand Down
4 changes: 2 additions & 2 deletions python/paddle/io/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Dataset(Generic[_T]):
>>> from paddle.io import Dataset
>>> # define a random dataset
>>> class RandomDataset(Dataset):
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
Expand Down Expand Up @@ -438,7 +438,7 @@ class ChainDataset(IterableDataset[Any]):
>>> # define a random dataset
>>> class RandomDataset(IterableDataset):
>>> class RandomDataset(IterableDataset): # type: ignore[type-arg]
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
Expand Down
22 changes: 11 additions & 11 deletions python/paddle/vision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ToTensor(BaseTransform[_InputT, "Tensor"]):
>>> img_arr = ((paddle.rand((4, 5, 3)) * 255.).astype('uint8')).numpy()
>>> fake_img = Image.fromarray(img_arr)
>>> transform = T.ToTensor()
>>> tensor = transform(fake_img)
>>> tensor = transform(fake_img) # type: ignore[call-overload]
>>> print(tensor.shape)
[3, 4, 5]
>>> print(tensor.dtype)
Expand Down Expand Up @@ -530,7 +530,7 @@ class RandomResizedCrop(BaseTransform[_InputT, _RetT]):
>>> transform = RandomResizedCrop(224)
>>> fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
Expand Down Expand Up @@ -736,7 +736,7 @@ class CenterCrop(BaseTransform[_InputT, _RetT]):
>>> transform = CenterCrop(224)
>>> fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
Expand Down Expand Up @@ -985,7 +985,7 @@ class Transpose(BaseTransform[_InputT, _RetT]):
>>> transform = Transpose()
>>> fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.shape)
(3, 300, 320)
Expand Down Expand Up @@ -1089,7 +1089,7 @@ class ContrastTransform(BaseTransform[_InputT, _RetT]):
>>> transform = ContrastTransform(0.4)
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
Expand Down Expand Up @@ -1138,7 +1138,7 @@ class SaturationTransform(BaseTransform[_InputT, _RetT]):
>>> transform = SaturationTransform(0.4)
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
"""
Expand Down Expand Up @@ -1184,7 +1184,7 @@ class HueTransform(BaseTransform[_InputT, _RetT]):
>>> transform = HueTransform(0.4)
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
Expand Down Expand Up @@ -1239,7 +1239,7 @@ class ColorJitter(BaseTransform[_InputT, _RetT]):
>>> transform = ColorJitter(0.4, 0.4, 0.4, 0.4)
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(224, 224)
Expand Down Expand Up @@ -1480,7 +1480,7 @@ class Pad(BaseTransform[_InputT, _RetT]):
>>> transform = Pad(2)
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(228, 228)
"""
Expand Down Expand Up @@ -1778,7 +1778,7 @@ class RandomRotation(BaseTransform[_InputT, _RetT]):
>>> transform = RandomRotation(90)
>>> fake_img = Image.fromarray((np.random.rand(200, 150, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(fake_img.size)
(150, 200)
"""
Expand Down Expand Up @@ -2012,7 +2012,7 @@ class Grayscale(BaseTransform[_InputT, _RetT]):
>>> transform = Grayscale()
>>> fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8))
>>> fake_img = transform(fake_img)
>>> fake_img = transform(fake_img) # type: ignore[call-overload]
>>> print(np.array(fake_img).shape)
(224, 224)
"""
Expand Down

0 comments on commit 6e805a2

Please sign in to comment.