Skip to content

Commit

Permalink
[Typing] Fix some type annotation issues (修复一些类型标注问题) (PaddlePaddle#66677)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
2 people authored and lixcli committed Aug 5, 2024
1 parent 2753b46 commit 2573bef
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
4 changes: 3 additions & 1 deletion python/paddle/incubate/autograd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,9 @@ def __init__(
else:
self._jacobian = _JacobianBatchFirst(func, xs)

def __getitem__(self, indexes: int | slice) -> Tensor:
def __getitem__(
self, indexes: int | slice | tuple[int | slice, ...]
) -> Tensor:
return self._jacobian[indexes]

@property
Expand Down
28 changes: 23 additions & 5 deletions python/paddle/io/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ def __len__(self) -> int:
"{}".format('__len__', self.__class__.__name__)
)

if TYPE_CHECKING:
# A virtual method for type checking only
def __iter__(self) -> Iterator[_T]:
...


class IterableDataset(Dataset[_T]):
"""
Expand Down Expand Up @@ -482,19 +487,32 @@ class Subset(Dataset[_T]):
.. code-block:: python
>>> import paddle
>>> from paddle.io import Subset
>>> # example 1:
>>> a = paddle.io.Subset(dataset=range(1, 4), indices=[0, 2])
>>> class RangeDataset(paddle.io.Dataset): # type: ignore[type-arg]
... def __init__(self, start, stop):
... self.start = start
... self.stop = stop
...
... def __getitem__(self, index):
... return index + self.start
...
... def __len__(self):
... return self.stop - self.start
>>> # Example 1:
>>> a = paddle.io.Subset(dataset=RangeDataset(1, 4), indices=[0, 2])
>>> print(list(a))
[1, 3]
>>> # example 2:
>>> b = paddle.io.Subset(dataset=range(1, 4), indices=[1, 1])
>>> # Example 2:
>>> b = paddle.io.Subset(dataset=RangeDataset(1, 4), indices=[1, 1])
>>> print(list(b))
[2, 2]
"""

dataset: Dataset[_T]
indices: Sequence[int]

def __init__(self, dataset: Dataset[_T], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
Expand Down
25 changes: 18 additions & 7 deletions python/paddle/io/dataloader/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import queue
import sys
import traceback
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -31,6 +32,9 @@
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from .flat import _flatten_batch

if TYPE_CHECKING:
from paddle.io import Dataset


class _IterableDatasetStopIteration:
def __init__(self, worker_id):
Expand Down Expand Up @@ -77,7 +81,7 @@ def is_alive(self):
_worker_info = None


def get_worker_info() -> WorkerInfo | None:
def get_worker_info() -> WorkerInfo:
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
Expand Down Expand Up @@ -105,7 +109,7 @@ def get_worker_info() -> WorkerInfo | None:
>>> import numpy as np
>>> from paddle.io import IterableDataset, DataLoader, get_worker_info
>>> class SplitedIterableDataset(IterableDataset):
>>> class SplitedIterableDataset(IterableDataset): # type: ignore[type-arg]
... def __init__(self, start, end):
... self.start = start
... self.end = end
Expand All @@ -118,8 +122,8 @@ def get_worker_info() -> WorkerInfo | None:
... else:
... per_worker = int(
... math.ceil((self.end - self.start) / float(
... worker_info.num_workers))) # type: ignore[attr-defined]
... worker_id = worker_info.id # type: ignore[attr-defined]
... worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
...
Expand Down Expand Up @@ -157,6 +161,11 @@ def get_worker_info() -> WorkerInfo | None:


class WorkerInfo:
num_workers: int
id: int
dataset: Dataset[Any]
seed: int

__initialized = False

def __init__(self, **kwargs):
Expand Down Expand Up @@ -390,9 +399,11 @@ def numpy2lodtensor(arr):
return lodtensor

tensor_list = [
numpy2lodtensor(b)
if isinstance(b, np.ndarray)
else b.get_tensor()
(
numpy2lodtensor(b)
if isinstance(b, np.ndarray)
else b.get_tensor()
)
for b in batch
]
out_queue.put((idx, tensor_list, structure))
Expand Down

0 comments on commit 2573bef

Please sign in to comment.