@@ -34,13 +34,16 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
3434 self ._name = name
3535 self ._prefix = pathlib .Path (prefix )
3636 self ._has_spans = 0
37+ self ._has_preference_spans = False
3738
3839 with self ._prefix .with_suffix (".idx" ).open ("rb" ) as stream :
3940 Assert .eq (stream .read (9 ), MEMMAP_INDEX_HEADER , msg = f"File: { stream .name } " )
4041 self ._version = struct .unpack ("<Q" , stream .read (8 ))[0 ]
41- assert self ._version in [1 , 2 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
42- if self ._version = = 2 :
42+ assert self ._version in [1 , 2 , 3 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
43+ if self ._version > = 2 :
4344 self ._has_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
45+ if self ._version >= 3 :
46+ self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
4447
4548 self ._dtype = MEMMAP_DTYPES [struct .unpack ("<B" , stream .read (1 ))[0 ]].numpy
4649 self ._num_documents = struct .unpack ("<Q" , stream .read (8 ))[0 ]
@@ -52,18 +55,23 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
5255
5356 self ._index_bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".idx" ), mode = "r" , order = "C" )
5457 self ._index_bin_buffer = memoryview (self ._index_bin_buffer_mmap )
58+
59+ # read document sizes
5560 self ._document_sizes = np .frombuffer (
5661 self ._index_bin_buffer , dtype = np .int32 , count = self ._num_documents , offset = offset
5762 )
63+
64+ # read pointers
5865 self ._pointers = np .frombuffer (
5966 self ._index_bin_buffer ,
6067 dtype = np .int64 ,
6168 count = self ._num_documents ,
6269 offset = offset + self ._document_sizes .nbytes ,
6370 )
6471
72+ # read spans
6573 self ._spans = None
66- if self ._has_spans and self ._version = = 2 :
74+ if self ._has_spans and self ._version > = 2 :
6775 self ._spans = []
6876 self ._num_spans = np .frombuffer (
6977 self ._index_bin_buffer ,
@@ -83,6 +91,36 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
8391 ).reshape (- 1 , 2 )
8492 )
8593
94+ # read preference spans
95+ self ._chosen_spans = None
96+ self ._rejected_spans = None
97+ if self ._has_preference_spans and self ._version >= 3 :
98+ self ._chosen_spans = []
99+ self ._rejected_spans = []
100+ chosen_span_offset = offset + self ._document_sizes .nbytes + self ._pointers .nbytes
101+ for idx in range (self ._num_documents ):
102+ self ._chosen_spans .append (
103+ np .frombuffer (
104+ self ._index_bin_buffer ,
105+ dtype = np .int32 ,
106+ count = 2 ,
107+ offset = chosen_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
108+ )
109+ )
110+
111+ rejected_span_offset = (
112+ offset + self ._document_sizes .nbytes + self ._pointers .nbytes + np .array (self ._chosen_spans ).nbytes
113+ )
114+ for idx in range (self ._num_documents ):
115+ self ._rejected_spans .append (
116+ np .frombuffer (
117+ self ._index_bin_buffer ,
118+ dtype = np .int32 ,
119+ count = 2 ,
120+ offset = rejected_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
121+ )
122+ )
123+
86124 self ._bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".bin" ), mode = "r" , order = "C" )
87125 self ._bin_buffer = memoryview (self ._bin_buffer_mmap )
88126
@@ -105,7 +143,12 @@ def __del__(self):
105143 del self ._index_bin_buffer_mmap
106144
107145 def get (
108- self , idx : int , offset : int = 0 , length : int | None = None , use_loss_masking_spans : bool = False
146+ self ,
147+ idx : int ,
148+ offset : int = 0 ,
149+ length : int | None = None ,
150+ use_loss_masking_spans : bool = False ,
151+ use_preference_loss_spans : bool = False ,
109152 ) -> GPTSample :
110153 token_ids = np .frombuffer (
111154 self ._bin_buffer ,
@@ -116,13 +159,53 @@ def get(
116159 sample_spans = None
117160 if use_loss_masking_spans and self ._spans is not None :
118161 sample_spans = self ._spans [idx ]
119- # adjust the spans for the offset and length
162+
163+ # keep only the spans that overlap the selected token range of the document
120164 sample_spans = sample_spans [
121165 (sample_spans [:, 0 ] < offset + len (token_ids )) & (sample_spans [:, 1 ] >= offset )
122166 ]
123- sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset
167+
168+ # clip span boundaries to the selected range, then subtract offset so they are relative to token_ids
169+ sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset # offset
124170 sample_spans [:, 1 ] = np .minimum (sample_spans [:, 1 ], offset + len (token_ids ) - 1 ) - offset
125- return GPTSample (token_ids = token_ids , loss_masking_spans = sample_spans )
171+
172+ chosen_span = None
173+ rejected_span = None
174+
175+ if use_preference_loss_spans :
176+ if not self ._has_preference_spans :
177+ raise ValueError ("No preference spans found in memmap dataset." )
178+ elif self ._has_preference_spans and self ._chosen_spans is None :
179+ raise ValueError ("Failed to read chosen spans from memmap dataset." )
180+ elif self ._has_preference_spans and self ._rejected_spans is None :
181+ raise ValueError ("Failed to read rejected spans from memmap dataset." )
182+ else :
183+ chosen_span = self ._chosen_spans [idx ]
184+
185+ # keep the chosen span only if it overlaps the selected token range (NOTE(review): the trailing [0] looks like it raises IndexError otherwise - confirm)
186+ chosen_span = chosen_span [(chosen_span [0 ] < offset + len (token_ids )) & (chosen_span [1 ] >= offset )][0 ]
187+
188+ # clip the chosen span to the selected range, then subtract offset so it is relative to token_ids
189+ chosen_span [0 ] = np .maximum (chosen_span [0 ], offset ) - offset # offset
190+ chosen_span [1 ] = np .minimum (chosen_span [1 ], offset + len (token_ids ) - 1 ) - offset
191+
192+ rejected_span = self ._rejected_spans [idx ]
193+
194+ # keep the rejected span only if it overlaps the selected token range (NOTE(review): the trailing [0] looks like it raises IndexError otherwise - confirm)
195+ rejected_span = rejected_span [
196+ (rejected_span [0 ] < offset + len (token_ids )) & (rejected_span [1 ] >= offset )
197+ ][0 ]
198+
199+ # clip the rejected span to the selected range, then subtract offset so it is relative to token_ids
200+ rejected_span [0 ] = np .maximum (rejected_span [0 ], offset ) - offset # offset
201+ rejected_span [1 ] = np .minimum (rejected_span [1 ], offset + len (token_ids ) - 1 ) - offset
202+
203+ return GPTSample (
204+ token_ids = token_ids ,
205+ loss_masking_spans = sample_spans ,
206+ chosen_span = chosen_span ,
207+ rejected_span = rejected_span ,
208+ )
126209
127210 @property
128211 def name (self ) -> str :
@@ -157,6 +240,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
157240 # number of spans for each document
158241 num_spans = []
159242 spans = []
243+ chosen_spans = []
244+ rejected_spans = []
160245
161246 prefix = pathlib .Path (prefix )
162247 prefix .parent .mkdir (parents = True , exist_ok = True )
@@ -182,6 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
182267 if document .loss_masking_spans is not None :
183268 num_spans .append (len (document .loss_masking_spans ))
184269 spans .append (document .loss_masking_spans )
270+ if document .chosen_span is not None :
271+ chosen_spans .append (document .chosen_span )
272+ if document .rejected_span is not None :
273+ rejected_spans .append (document .rejected_span )
185274 offset += doc_length * np .dtype (dtype ).itemsize
186275 num_documents += 1
187276
@@ -193,15 +282,20 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
193282 spans = np .vstack (spans , dtype = np .int32 )
194283 else :
195284 spans = np .array (spans , dtype = np .int32 )
285+ chosen_spans = np .array (chosen_spans , dtype = np .int32 ).reshape (- 1 , 2 )
286+ rejected_spans = np .array (rejected_spans , dtype = np .int32 ).reshape (- 1 , 2 )
196287
197288 # Write the index file (.idx)
198289 with prefix .with_suffix (".idx" ).open ("wb" ) as idx_stream :
199290 idx_stream .write (MEMMAP_INDEX_HEADER )
200291 # Indicates the version
201292 # Version 2 optionally adds loss-masking spans
202- idx_stream .write (struct .pack ("<Q" , 2 ))
293+ # Version 3 optionally adds chosen/rejected spans
294+ idx_stream .write (struct .pack ("<Q" , 3 ))
203295 # Flag to indicate whether loss-masking spans are present
204296 idx_stream .write (struct .pack ("<B" , 1 if spans .size > 0 else 0 ))
297+ # Flag to indicate whether preference (chosen/rejected) spans are present; set only when both are non-empty
298+ idx_stream .write (struct .pack ("<B" , 1 if chosen_spans .size > 0 and rejected_spans .size > 0 else 0 ))
205299 # Data type
206300 idx_stream .write (struct .pack ("<B" , MEMMAP_DTYPES_INV [DataType .from_numpy (dtype .type )]))
207301 # "Number of sequences", same as documents in our case
@@ -216,5 +310,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
216310 idx_stream .write (num_spans .tobytes (order = "C" ))
217311 # Span indices for each document
218312 idx_stream .write (spans .tobytes (order = "C" ))
313+ # Chosen span boundaries for each document
314+ idx_stream .write (chosen_spans .tobytes (order = "C" ))
315+ # Rejected span boundaries for each document
316+ idx_stream .write (rejected_spans .tobytes (order = "C" ))
219317 # Document indices, unused but needed for compatibility with Megatron-LM
220318 idx_stream .write (np .arange (num_documents + 1 , dtype = np .int64 ).tobytes (order = "C" ))
0 commit comments