Add back missing modules (#73)
* add back refact data pipeline

* add back filters_chat.py

* grammar
JegernOUTT authored Aug 12, 2023
1 parent db29f83 commit a4f7a1c
Showing 5 changed files with 528 additions and 124 deletions.
23 changes: 22 additions & 1 deletion refact_data_pipeline/datadef.py
@@ -1,5 +1,6 @@
import json
from typing import List, Set
from pathlib import Path
from typing import List, Set, Tuple


class DatasetDef:
@@ -17,6 +18,26 @@ def __repr__(self):
self.cloud_path, len(self.cloud_files), str(self.to_apply))


class DatasetDumpedDef:
def __init__(
self,
path: str,
to_apply: Set[str],
suffixes: Tuple[str, ...] = ('.h5', '.hdf5')
):
assert not path.startswith('gs://'), "DatasetDumpedDef doesn't support cloud-based paths " \
"because of random access to files"
# These paths are local, not cloud; the attribute names are kept for compatibility with DatasetDef
self.cloud_path = path
self.cloud_files = [p for p in sorted(Path(path).iterdir(), key=lambda p: p.name)
if p.suffix in suffixes]
self.to_apply = to_apply

def __repr__(self):
return "dataset definition %s with %i files and filters %s" % (
self.cloud_path, len(self.cloud_files), str(self.to_apply))


class DatasetMix:
def __init__(self,
dataset_defs: List[DatasetDef],
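
For orientation, a minimal usage sketch of the new DatasetDumpedDef (not part of the commit): the directory path and filter name below are made up, and the directory is assumed to already contain dumped .h5/.hdf5 shards.

from refact_data_pipeline.datadef import DatasetDumpedDef

# Hypothetical local directory of dumped HDF5 shards; DatasetDumpedDef rejects gs:// paths.
dumped = DatasetDumpedDef(
    path="/data/dumped_train",
    to_apply={"tokens_from_odm"},   # made-up filter name, applied later by the pipeline
)
print(dumped)                # repr reports the path, file count and filters
print(dumped.cloud_files)    # local pathlib.Path objects, sorted by file name
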
66 changes: 66 additions & 0 deletions refact_data_pipeline/filters_chat.py
@@ -0,0 +1,66 @@
import traceback
import traces
from typing import Dict

from code_contrast.format_2023q2 import format, packing
from code_contrast.format_2023q2.el_msg import MsgElement
from code_contrast.format_2023q2.element import Format2023q2
from code_contrast.format_2023q2.from_orig_dest_message import from_odm_dict
from refact_data_pipeline import DatasetOpts
from refact_encoding.encoding import RefactEncoding


class Chat2023Q2FromODM:
def __init__(self,
inner_filter,
dataopts: DatasetOpts):
self.inner_filter = inner_filter
self.n_ctx = dataopts.get("n_ctx", 2048)
self.enc: RefactEncoding = dataopts.encoding
self.fmt: Format2023q2 = format.format_2023q2_escape(self.enc)

def __iter__(self):
stats: Dict[str, int] = {
"chatskip_failed": 0,
}
for odm in self.inner_filter:
assert len(odm['chat']) > 0
plan = []
for item in odm['chat']:
if len(item['instruction']) > 0:
plan.append(MsgElement("SYSTEM", item['instruction']))
if len(item['input']) > 0:
plan.append(MsgElement("USER", item['input']))
plan.append(MsgElement("ASSISTANT", item['output']))

try:
pack = packing.Packer(self.fmt)
for p in plan:
pack.add_to_plan(p)
pack.pack_context(
start_from_plan_n=0,
mask_from_plan_n=0,
limit_ctx_n=self.n_ctx,
limit_aux_n=0,
add_eot=True,
for_training=True
)
except Exception:
msg = "{\n"
for key, val in odm.items():
msg += f" {repr(key)}: {repr(val)},\n"
msg += "}"
traces.log(msg)
traces.log(traceback.format_exc())
stats["chatskip_failed"] += 1
continue
first = [1] + [0] * (len(pack.r) - 1)
assert len(pack.r) == len(first)
assert len(pack.r) == len(pack.m)
emit = {
"tokens": pack.r,
"mask": pack.m,
"first": first,
"stats": {**odm["stats"], **stats}
}
yield emit
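
For reference, the record shape Chat2023Q2FromODM expects from its inner_filter, as implied by the loop above (this example is not part of the commit and its field values are made up):

example_odm = {
    "chat": [
        {
            "instruction": "You are a helpful assistant.",  # non-empty -> SYSTEM message
            "input": "What does this function do?",         # non-empty -> USER message
            "output": "It parses the config file.",         # always -> ASSISTANT message
        },
    ],
    "stats": {},  # merged into the "stats" dict of every emitted sample
}
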
107 changes: 107 additions & 0 deletions refact_data_pipeline/filters_hdfs.py
@@ -0,0 +1,107 @@
import random
from pathlib import Path
from typing import Tuple, Optional, Any, List, Dict

import mpi4py.MPI as mpi
import numpy as np
import tables as tb

from refact_data_pipeline import DatasetOpts


def _try_open(path: Path) -> Optional[Any]:
try:
return tb.open_file(str(path), mode='r')
except Exception as e:
print(f'Cannot open the file {path}: {e}')
return None



class Hdf5Dataset:
"""
A class that maps HDF5 files to a flat array of data.

Parameters
----------
dataopts : DatasetOpts
Dataset options; "keys" selects the columns to read and "hdfs_seed" fixes the shuffle seed.
files : List[Path]
Paths of the HDF5 files to read.
comm : Optional[mpi4py.MPI.Comm]
The MPI communicator; when given, the shuffled index is split evenly across ranks.
cold_restart_skip : Optional[int]
Number of records to skip on the first pass, to resume an interrupted run.
"""

def __init__(
self,
dataopts: DatasetOpts,
files: List[Path],
comm: Optional[mpi.Comm] = None,
cold_restart_skip: Optional[int] = None
):
files = [_try_open(p) for p in files]
files = [f for f in files if f is not None]
assert len(files) > 0
self.files = files
self.tables = [file.root.data for file in self.files]
self.keys = dataopts.get("keys", "tokens;mask").split(';')
self.manual_seed = dataopts.get("hdfs_seed", None)
self.comm = comm
self.cold_restart_skip = cold_restart_skip
if self.cold_restart_skip is not None:
assert self.manual_seed is not None, \
"`cold_restart_skip` requires the manual seed, otherwise it doesn't make sence"
self.tables_lengths = [len(t) for t in self.tables]
self.tables_lengths_cumsum = np.cumsum(self.tables_lengths)
self.overall_length = self.tables_lengths_cumsum[-1]
self.index = self.__reshuffle()
self.tables_iter = None

def __del__(self):
for file in self.files:
file.close()

def __reshuffle(self) -> np.ndarray:
if self.manual_seed is None:
seed = random.randint(0, 2 ** 32 - 1)
if self.comm is not None:
seed = self.comm.bcast(seed, root=0)
else:
seed = self.manual_seed

rng = np.random.default_rng(seed)
index = rng.choice(self.overall_length, self.overall_length, replace=False)

if self.comm is not None:
rank_len = len(index) // self.comm.size
index = index[:rank_len * self.comm.size]
index = index[self.comm.rank * rank_len:(self.comm.rank + 1) * rank_len]

return index

def reshuffle(self):
assert self.manual_seed is None, "`reshuffle` does nothing when a manual seed is set, which is probably a bug"
assert self.tables_iter is None, "`reshuffle` cannot be called while iterating"
self.index = self.__reshuffle()

def __len__(self):
return len(self.index)

def __next__(self) -> Dict[str, Any]:
assert self.tables_iter is not None, "`__next__` called before `__iter__`"
iter_n, idx = next(self.tables_iter)
data = dict(zip(self.keys, self[idx]))
data['stats'] = dict(record_n=int(iter_n), restart=int(iter_n))
return data

def __iter__(self) -> 'Hdf5Dataset':
self.tables_iter = iter(enumerate(self.index))
if self.cold_restart_skip is not None:
for _ in range(self.cold_restart_skip):
next(self.tables_iter)
self.cold_restart_skip = None
return self

def __getitem__(self, idx: int) -> Tuple[Any, ...]:
table_idx, table_cumsum = next(
((i, t) for i, t in enumerate(self.tables_lengths_cumsum) if idx < t)
)
row_idx = idx - (table_cumsum - self.tables_lengths[table_idx]) - 1
row = self.tables[table_idx][row_idx]
return tuple(row[k].tolist() for k in self.keys)
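
As a self-contained illustration of the sharding performed in Hdf5Dataset.__reshuffle (not part of the commit; the seed and sizes are made up): every rank draws the same permutation from the shared seed, the permutation is truncated to a multiple of the world size, and each rank keeps its own contiguous slice.

import numpy as np

overall_length, world_size = 10, 3
rng = np.random.default_rng(1234)      # in the class, rank 0 picks the seed and broadcasts it
index = rng.choice(overall_length, overall_length, replace=False)

rank_len = len(index) // world_size    # 3 records per rank; 1 leftover record is dropped
index = index[:rank_len * world_size]
shards = [index[r * rank_len:(r + 1) * rank_len] for r in range(world_size)]
# shards[r] is what rank r iterates over; no record appears on two ranks
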