11import csv
2+ from itertools import islice
23from typing import Any , Dict , Iterator , List , Optional , Sequence , TextIO , Union
34
45from torchdata .nodes .base_node import BaseNode
@@ -13,7 +14,7 @@ class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
1314 return_dict: Return rows as dictionaries (requires has_header=True)
1415 """
1516
16- LINE_NUM_KEY = "line_num "
17+ NUM_LINES_YIELDED = "num_lines_yielded "
1718 HEADER_KEY = "header"
1819
1920 def __init__ (
@@ -22,6 +23,7 @@ def __init__(
2223 has_header : bool = False ,
2324 delimiter : str = "," ,
2425 return_dict : bool = False ,
26+ encoding : str = "utf-8" ,
2527 ):
2628 super ().__init__ ()
2729 self .file_path = file_path
@@ -30,64 +32,84 @@ def __init__(
3032 self .return_dict = return_dict
3133 if return_dict and not has_header :
3234 raise ValueError ("return_dict=True requires has_header=True" )
35+ self .encoding = encoding
3336 self ._file : Optional [TextIO ] = None
3437 self ._reader : Optional [Iterator [Union [List [str ], Dict [str , str ]]]] = None
3538 self ._header : Optional [Sequence [str ]] = None
36- self ._line_num : int = 0
39+ self ._num_lines_yielded : int = 0
3740 self .reset () # Initialize reader
3841
3942 def reset (self , initial_state : Optional [Dict [str , Any ]] = None ):
40- super ().reset (initial_state )
41-
42- if self ._file and not self ._file .closed :
43- self ._file .close ()
43+ super ().reset ()
44+ self .close ()
45+
46+ # Reopen the file and reset counters
47+ self ._file = open (self .file_path , encoding = self .encoding )
48+ self ._num_lines_yielded = 0
49+ if initial_state is not None :
50+ self ._handle_initial_state (initial_state )
51+ else :
52+ self ._initialize_reader ()
4453
45- self ._file = open (self .file_path , newline = "" , encoding = "utf-8" )
46- self ._line_num = 0
54+ def _handle_initial_state (self , state : Dict [str , Any ]):
55+ """Restore reader state from checkpoint."""
56+ # Validate header compatibility
57+ if (not self .has_header and self .HEADER_KEY in state ) or (self .has_header and state [self .HEADER_KEY ] is None ):
58+ raise ValueError (f"Check if has_header={ self .has_header } matches the state header={ state [self .HEADER_KEY ]} " )
4759
48- if initial_state :
49- self ._header = initial_state .get (self .HEADER_KEY )
50- target_line_num = initial_state [self .LINE_NUM_KEY ]
60+ self ._header = state .get (self .HEADER_KEY )
61+ target_line_num = state [self .NUM_LINES_YIELDED ]
62+ assert self ._file is not None
63+ # Create appropriate reader
64+ if self .return_dict :
5165
52- if self .return_dict :
53- if self ._header is None :
54- raise ValueError ("return_dict=True requires has_header=True" )
55- self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter , fieldnames = self ._header )
56- else :
57- self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
66+ self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter , fieldnames = self ._header )
67+ else :
68+ self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
69+ # Skip header if needed (applies only when file has header)
70+
71+ assert isinstance (self ._reader , Iterator )
72+ if self .has_header :
73+ try :
74+ next (self ._reader ) # Skip header line
75+ except StopIteration :
76+ pass # Empty file
77+ # Fast-forward to target line using efficient slicing
78+ consumed = sum (1 for _ in islice (self ._reader , target_line_num ))
79+ self ._num_lines_yielded = consumed
80+
81+ def _initialize_reader (self ):
82+ """Create fresh reader without state."""
83+ assert self ._file is not None
84+ if self .return_dict :
85+ self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter )
86+ self ._header = self ._reader .fieldnames
87+ else :
88+ self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
5889
59- assert isinstance (self ._reader , Iterator )
6090 if self .has_header :
61- next (self ._reader ) # Skip header
62- for _ in range (target_line_num - self ._line_num ):
63- try :
64- next (self ._reader )
65- self ._line_num += 1
66- except StopIteration :
67- break
68- else :
6991
70- if self .return_dict :
71- self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter )
72- self ._header = self ._reader .fieldnames
73- else :
74- self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
75- if self .has_header :
92+ try :
7693 self ._header = next (self ._reader )
94+ except StopIteration :
95+ self ._header = None # Handle empty file
7796
7897 def next (self ) -> Union [List [str ], Dict [str , str ]]:
7998 try :
8099 assert isinstance (self ._reader , Iterator )
81100 row = next (self ._reader )
82- self ._line_num += 1
101+ self ._num_lines_yielded += 1
83102 return row
84103
85104 except StopIteration :
86105 self .close ()
87106 raise
88107
89108 def get_state (self ) -> Dict [str , Any ]:
90- return {self .LINE_NUM_KEY : self ._line_num , self .HEADER_KEY : self ._header }
109+ return {
110+ self .NUM_LINES_YIELDED : self ._num_lines_yielded ,
111+ self .HEADER_KEY : self ._header ,
112+ }
91113
92114 def close (self ):
93115 if self ._file and not self ._file .closed :
0 commit comments