diff --git a/test/test_local_io.py b/test/test_local_io.py
index f15048b86..7c4185763 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -159,9 +159,15 @@ def make_path(fname):
         expected_res = [("1.csv", ["key", "item"]), ("1.csv", ["a", "1"]), ("1.csv", ["b", "2"]), ("empty2.csv", [])]
         self.assertEqual(expected_res, list(csv_parser_dp))
 
+        # Functional Test: yield one row at time from each file as tuple instead of list, skipping over empty content
+        csv_parser_dp = datapipe3.parse_csv(as_tuple=True)
+        expected_res = [("key", "item"), ("a", "1"), ("b", "2"), ()]
+        self.assertEqual(expected_res, list(csv_parser_dp))
+
         # Reset Test:
         csv_parser_dp = CSVParser(datapipe3, return_path=True)
         n_elements_before_reset = 2
+        expected_res = [("1.csv", ["key", "item"]), ("1.csv", ["a", "1"]), ("1.csv", ["b", "2"]), ("empty2.csv", [])]
         res_before_reset, res_after_reset = reset_after_n_next_calls(csv_parser_dp, n_elements_before_reset)
         self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
         self.assertEqual(expected_res, res_after_reset)
diff --git a/torchdata/datapipes/iter/util/plain_text_reader.py b/torchdata/datapipes/iter/util/plain_text_reader.py
index da11ad2d0..6a03368ee 100644
--- a/torchdata/datapipes/iter/util/plain_text_reader.py
+++ b/torchdata/datapipes/iter/util/plain_text_reader.py
@@ -25,6 +25,7 @@ def __init__(
         encoding="utf-8",
         errors: str = "ignore",
         return_path: bool = False,
+        as_tuple: bool = False,
     ) -> None:
         if skip_lines < 0:
             raise ValueError("'skip_lines' is required to be a positive integer.")
@@ -34,6 +35,7 @@ def __init__(
         self._encoding = encoding
         self._errors = errors
         self._return_path = return_path
+        self._as_tuple = as_tuple
 
     def skip_lines(self, file: IO) -> Union[Iterator[bytes], Iterator[str]]:
         with contextlib.suppress(StopIteration):
@@ -68,6 +70,16 @@ def return_path(self, stream: Iterator[D], *, path: str) -> Iterator[Union[D, Tuple[str, D]]]:
         for data in stream:
             yield path, data
 
+    def as_tuple(self, stream: Iterator[D]) -> Iterator[Union[D, Tuple]]:
+        if not self._as_tuple:
+            yield from stream
+            return
+        for data in stream:
+            if isinstance(data, list):
+                yield tuple(data)
+            else:
+                yield data
+
 
 @functional_datapipe("readlines")
 class LineReaderIterDataPipe(IterDataPipe[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]):
@@ -136,6 +148,7 @@ def __init__(
         encoding="utf-8",
         errors: str = "ignore",
         return_path: bool = True,
+        as_tuple: bool = False,
         **fmtparams,
     ) -> None:
         self.source_datapipe = source_datapipe
@@ -146,6 +159,7 @@ def __init__(
             encoding=encoding,
             errors=errors,
             return_path=return_path,
+            as_tuple=as_tuple,
         )
         self.fmtparams = fmtparams
 
@@ -154,6 +168,7 @@ def __iter__(self) -> Iterator[Union[D, Tuple[str, D]]]:
             stream = self._helper.skip_lines(file)
             stream = self._helper.decode(stream)
             stream = self._csv_reader(stream, **self.fmtparams)
+            stream = self._helper.as_tuple(stream)  # type: ignore[assignment]
             yield from self._helper.return_path(stream, path=path)  # type: ignore[misc]
 
 
@@ -173,6 +188,7 @@ class CSVParserIterDataPipe(_CSVBaseParserIterDataPipe):
         errors: the error handling scheme used while decoding
         return_path: if ``True``, each line will return a tuple of path and contents,
             rather than just the contents
+        as_tuple: if ``True``, each line will return a tuple instead of a list
 
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper, FileOpener
@@ -196,6 +212,7 @@ def __init__(
         encoding: str = "utf-8",
         errors: str = "ignore",
         return_path: bool = False,
+        as_tuple: bool = False,
         **fmtparams,
     ) -> None:
         super().__init__(
@@ -206,6 +223,7 @@ def __init__(
             encoding=encoding,
             errors=errors,
             return_path=return_path,
+            as_tuple=as_tuple,
             **fmtparams,
         )
 