-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add back refact data pipeline * add back filters_chat.py * grammar
- Loading branch information
1 parent
db29f83
commit a4f7a1c
Showing
5 changed files
with
528 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import traceback | ||
import traces | ||
from typing import Dict | ||
|
||
from code_contrast.format_2023q2 import format, packing | ||
from code_contrast.format_2023q2.el_msg import MsgElement | ||
from code_contrast.format_2023q2.element import Format2023q2 | ||
from code_contrast.format_2023q2.from_orig_dest_message import from_odm_dict | ||
from refact_data_pipeline import DatasetOpts | ||
from refact_encoding.encoding import RefactEncoding | ||
|
||
|
||
class Chat2023Q2FromODM: | ||
def __init__(self, | ||
inner_filter, | ||
dataopts: DatasetOpts): | ||
self.inner_filter = inner_filter | ||
self.n_ctx = dataopts.get("n_ctx", 2048) | ||
self.enc: RefactEncoding = dataopts.encoding | ||
self.fmt: Format2023q2 = format.format_2023q2_escape(self.enc) | ||
|
||
def __iter__(self): | ||
stats: Dict[str, int] = { | ||
"chatskip_failed": 0, | ||
} | ||
for odm in self.inner_filter: | ||
assert len(odm['chat']) > 0 | ||
plan = [] | ||
for item in odm['chat']: | ||
if len(item['instruction']) > 0: | ||
plan.append(MsgElement("SYSTEM", item['instruction'])) | ||
if len(item['input']) > 0: | ||
plan.append(MsgElement("USER", item['input'])) | ||
plan.append(MsgElement("ASSISTANT", item['output'])) | ||
|
||
try: | ||
pack = packing.Packer(self.fmt) | ||
for p in plan: | ||
pack.add_to_plan(p) | ||
pack.pack_context( | ||
start_from_plan_n=0, | ||
mask_from_plan_n=0, | ||
limit_ctx_n=self.n_ctx, | ||
limit_aux_n=0, | ||
add_eot=True, | ||
for_training=True | ||
) | ||
except Exception as e: | ||
msg = "{\n" | ||
for key, val in odm.items(): | ||
msg += f" {repr(key)}: {repr(val)},\n" | ||
msg += "}" | ||
traces.log(msg) | ||
traces.log(traceback.format_exc()) | ||
stats["chatskip_failed"] += 1 | ||
continue | ||
first = [1] + [0] * (len(pack.r) - 1) | ||
assert len(pack.r) == len(first) | ||
assert len(pack.r) == len(pack.m) | ||
emit = { | ||
"tokens": pack.r, | ||
"mask": pack.m, | ||
"first": first, | ||
"stats": {**odm["stats"], **stats} | ||
} | ||
yield emit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import random | ||
from pathlib import Path | ||
from typing import Tuple, Optional, Any, List, Dict | ||
|
||
import mpi4py.MPI as mpi | ||
import numpy as np | ||
import tables as tb | ||
|
||
from refact_data_pipeline import DatasetOpts | ||
|
||
|
||
def _try_open(path: Path) -> Optional[Any]: | ||
try: | ||
return tb.open_file(str(path), mode='r') | ||
except Exception as e: | ||
print(f'Cannot open the file {path}: {e}') | ||
return None | ||
|
||
|
||
|
||
class Hdf5Dataset: | ||
""" | ||
A class that maps HDF5 files to flat array of data | ||
Parameters | ||
---------- | ||
comm : Optional[mpi4py.MPI.Comm] | ||
The MPI communicator. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataopts: DatasetOpts, | ||
files: List[Path], | ||
comm: Optional[mpi.Comm] = None, | ||
cold_restart_skip: Optional[int] = None | ||
): | ||
files = [_try_open(p) for p in files] | ||
files = [f for f in files if f is not None] | ||
assert len(files) > 0 | ||
self.files = files | ||
self.tables = [file.root.data for file in self.files] | ||
self.keys = dataopts.get("keys", "tokens;mask").split(';') | ||
self.manual_seed = dataopts.get("hdfs_seed", None) | ||
self.comm = comm | ||
self.cold_restart_skip = cold_restart_skip | ||
if self.cold_restart_skip is not None: | ||
assert self.manual_seed is not None, \ | ||
"`cold_restart_skip` requires the manual seed, otherwise it doesn't make sence" | ||
self.tables_lengths = [len(t) for t in self.tables] | ||
self.tables_lengths_cumsum = np.cumsum(self.tables_lengths) | ||
self.overall_length = self.tables_lengths_cumsum[-1] | ||
self.index = self.__reshuffle() | ||
self.tables_iter = None | ||
|
||
def __del__(self): | ||
for file in self.files: | ||
file.close() | ||
|
||
def __reshuffle(self) -> np.ndarray: | ||
if self.manual_seed is None: | ||
seed = random.randint(0, 2 ** 32 - 1) | ||
if self.comm is not None: | ||
seed = self.comm.bcast(seed, root=0) | ||
else: | ||
seed = self.manual_seed | ||
|
||
rng = np.random.default_rng(seed) | ||
index = rng.choice(self.overall_length, self.overall_length, replace=False) | ||
|
||
if self.comm is not None: | ||
rank_len = len(index) // self.comm.size | ||
index = index[:rank_len * self.comm.size] | ||
index = index[self.comm.rank * rank_len:(self.comm.rank + 1) * rank_len] | ||
|
||
return index | ||
|
||
def reshuffle(self): | ||
assert self.manual_seed is None, "`reshuffle` with the manual seed leads to do nothing, it may be a bug" | ||
self.index = self.__reshuffle() | ||
assert self.tables_iter is None, "`reshuffle` cannot be called while iterating" | ||
|
||
def __len__(self): | ||
return len(self.index) | ||
|
||
def __next__(self) -> Dict[str, Any]: | ||
assert self.tables_iter is not None, "`__next__` called before `__iter__`" | ||
iter_n, idx = next(self.tables_iter) | ||
data = dict(zip(self.keys, self[idx])) | ||
data['stats'] = dict(record_n=int(iter_n), restart=int(iter_n)) | ||
return data | ||
|
||
def __iter__(self) -> 'Hdf5Dataset': | ||
self.tables_iter = iter(enumerate(self.index)) | ||
if self.cold_restart_skip is not None: | ||
for _ in range(self.cold_restart_skip): | ||
next(self.tables_iter) | ||
self.cold_restart_skip = None | ||
return self | ||
|
||
def __getitem__(self, idx: int) -> Tuple[Any, ...]: | ||
table_idx, table_cumsum = next( | ||
((i, t) for i, t in enumerate(self.tables_lengths_cumsum) if idx < t) | ||
) | ||
row_idx = idx - (table_cumsum - self.tables_lengths[table_idx]) - 1 | ||
row = self.tables[table_idx][row_idx] | ||
return tuple(row[k].tolist() for k in self.keys) |
Oops, something went wrong.