diff --git a/python/paddle/batch.py b/python/paddle/batch.py index 98e5a6a14545a..be0cfa5e00ab2 100644 --- a/python/paddle/batch.py +++ b/python/paddle/batch.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import Callable, Generator, TypeVar + +_T = TypeVar('_T') __all__ = [] -def batch(reader, batch_size, drop_last=False): +def batch( + reader: Callable[[], Generator[_T, None, None]], + batch_size: int, + drop_last: bool = False, +) -> Callable[[], Generator[list[_T], None, None]]: """ This operator creates a batched reader which combines the data from the input reader to batched data.