diff --git a/refact_data_pipeline/datadef.py b/refact_data_pipeline/datadef.py
index 73791a5b..c52a4fcf 100644
--- a/refact_data_pipeline/datadef.py
+++ b/refact_data_pipeline/datadef.py
@@ -1,5 +1,6 @@
 import json
-from typing import List, Set
+from pathlib import Path
+from typing import List, Set, Tuple
 
 
 class DatasetDef:
@@ -17,6 +18,26 @@ def __repr__(self):
             self.cloud_path, len(self.cloud_files), str(self.to_apply))
 
 
+class DatasetDumpedDef:
+    def __init__(
+            self,
+            path: str,
+            to_apply: Set[str],
+            suffixes: Tuple[str, ...] = ('.h5', '.hdf5')
+    ):
+        assert not path.startswith('gs://'), "DatasetDumpedDef doesn't support cloud-based paths " \
+                                             "because of random access to files"
+        # These are local paths; the attribute names are kept for compatibility with DatasetDef
+        self.cloud_path = path
+        self.cloud_files = [p for p in sorted(Path(path).iterdir(), key=lambda p: p.name)
+                            if p.suffix in suffixes]
+        self.to_apply = to_apply
+
+    def __repr__(self):
+        return "dataset definition %s with %i files and filters %s" % (
+            self.cloud_path, len(self.cloud_files), str(self.to_apply))
+
+
 class DatasetMix:
     def __init__(self,
                  dataset_defs: List[DatasetDef],
diff --git a/refact_data_pipeline/filters_chat.py b/refact_data_pipeline/filters_chat.py
new file mode 100644
index 00000000..214ba097
--- /dev/null
+++ b/refact_data_pipeline/filters_chat.py
@@ -0,0 +1,66 @@
+import traceback
+import traces
+from typing import Dict
+
+from code_contrast.format_2023q2 import format, packing
+from code_contrast.format_2023q2.el_msg import MsgElement
+from code_contrast.format_2023q2.element import Format2023q2
+from code_contrast.format_2023q2.from_orig_dest_message import from_odm_dict
+from refact_data_pipeline import DatasetOpts
+from refact_encoding.encoding import RefactEncoding
+
+
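+# Expected shape of the incoming ODM records, for illustration only (these are the
+# fields the code below reads; the upstream filter is assumed to supply them):
+#   {"chat": [{"instruction": "...", "input": "...", "output": "..."}], "stats": {...}}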
+class Chat2023Q2FromODM:
+    def __init__(self,
+                 inner_filter,
+                 dataopts: DatasetOpts):
+        self.inner_filter = inner_filter
+        self.n_ctx = dataopts.get("n_ctx", 2048)
+        self.enc: RefactEncoding = dataopts.encoding
+        self.fmt: Format2023q2 = format.format_2023q2_escape(self.enc)
+
+    def __iter__(self):
+        stats: Dict[str, int] = {
+            "chatskip_failed": 0,
+        }
+        for odm in self.inner_filter:
+            assert len(odm['chat']) > 0
+            plan = []
+            for item in odm['chat']:
+                if len(item['instruction']) > 0:
+                    plan.append(MsgElement("SYSTEM", item['instruction']))
+                if len(item['input']) > 0:
+                    plan.append(MsgElement("USER", item['input']))
+                plan.append(MsgElement("ASSISTANT", item['output']))
+
+            try:
+                pack = packing.Packer(self.fmt)
+                for p in plan:
+                    pack.add_to_plan(p)
+                pack.pack_context(
+                    start_from_plan_n=0,
+                    mask_from_plan_n=0,
+                    limit_ctx_n=self.n_ctx,
+                    limit_aux_n=0,
+                    add_eot=True,
+                    for_training=True
+                )
+            except Exception as e:
+                msg = "{\n"
+                for key, val in odm.items():
+                    msg += f" {repr(key)}: {repr(val)},\n"
+                msg += "}"
+                traces.log(msg)
+                traces.log(traceback.format_exc())
+                stats["chatskip_failed"] += 1
+                continue
+            first = [1] + [0] * (len(pack.r) - 1)
+            assert len(pack.r) == len(first)
+            assert len(pack.r) == len(pack.m)
+            emit = {
+                "tokens": pack.r,
+                "mask": pack.m,
+                "first": first,
+                "stats": {**odm["stats"], **stats}
+            }
+            yield emit
diff --git a/refact_data_pipeline/filters_hdfs.py b/refact_data_pipeline/filters_hdfs.py
new file mode 100644
index 00000000..24e34f4d
--- /dev/null
+++ b/refact_data_pipeline/filters_hdfs.py
@@ -0,0 +1,107 @@
+import random
+from pathlib import Path
+from typing import Tuple, Optional, Any, List, Dict
+
+import mpi4py.MPI as mpi
+import numpy as np
+import tables as tb
+
+from refact_data_pipeline import DatasetOpts
+
+
+def _try_open(path: Path) -> Optional[Any]:
+    try:
+        return tb.open_file(str(path), mode='r')
+    except Exception as e:
+        print(f'Cannot open the file {path}: {e}')
+        return None
+
+
+class Hdf5Dataset:
+    """
+    A class that maps HDF5 files to a flat array of records.
+
+    Parameters
+    ----------
+    comm : Optional[mpi4py.MPI.Comm]
+        The MPI communicator.
+    """
+
+    def __init__(
+            self,
+            dataopts: DatasetOpts,
+            files: List[Path],
+            comm: Optional[mpi.Comm] = None,
+            cold_restart_skip: Optional[int] = None
+    ):
+        files = [_try_open(p) for p in files]
+        files = [f for f in files if f is not None]
+        assert len(files) > 0
+        self.files = files
+        self.tables = [file.root.data for file in self.files]
+        self.keys = dataopts.get("keys", "tokens;mask").split(';')
+        self.manual_seed = dataopts.get("hdfs_seed", None)
+        self.comm = comm
+        self.cold_restart_skip = cold_restart_skip
+        if self.cold_restart_skip is not None:
+            assert self.manual_seed is not None, \
+                "`cold_restart_skip` requires a manual seed, otherwise it makes no sense"
+        self.tables_lengths = [len(t) for t in self.tables]
+        self.tables_lengths_cumsum = np.cumsum(self.tables_lengths)
+        self.overall_length = self.tables_lengths_cumsum[-1]
+        self.index = self.__reshuffle()
+        self.tables_iter = None
+
+    def __del__(self):
+        for file in self.files:
+            file.close()
+
+    def __reshuffle(self) -> np.ndarray:
+        if self.manual_seed is None:
+            seed = random.randint(0, 2 ** 32 - 1)
+            if self.comm is not None:
+                seed = self.comm.bcast(seed, root=0)
+        else:
+            seed = self.manual_seed
+
+        rng = np.random.default_rng(seed)
+        index = rng.choice(self.overall_length, self.overall_length, replace=False)
+
+        if self.comm is not None:
+            rank_len = len(index) // self.comm.size
+            index = index[:rank_len * self.comm.size]
+            index = index[self.comm.rank * rank_len:(self.comm.rank + 1) * rank_len]
+
+        return index
+
+    def reshuffle(self):
+        assert self.tables_iter is None, "`reshuffle` cannot be called while iterating"
+        assert self.manual_seed is None, "`reshuffle` with a manual seed does nothing, which is probably a bug"
+        self.index = self.__reshuffle()
+
+    def __len__(self):
+        return len(self.index)
+
+    def __next__(self) -> Dict[str, Any]:
+        assert self.tables_iter is not None, "`__next__` called before `__iter__`"
+        iter_n, idx = next(self.tables_iter)
+        data = dict(zip(self.keys, self[idx]))
+        data['stats'] = dict(record_n=int(iter_n), restart=int(iter_n))
+        return data
+
+    def __iter__(self) -> 'Hdf5Dataset':
+        self.tables_iter = iter(enumerate(self.index))
+        if self.cold_restart_skip is not None:
+            for _ in range(self.cold_restart_skip):
+                next(self.tables_iter)
+            self.cold_restart_skip = None
+        return self
+
+    def __getitem__(self, idx: int) -> Tuple[Any, ...]:
+        # pick the first table whose cumulative length is above idx, then index into it
+        table_idx, table_cumsum = next(
+            ((i, t) for i, t in enumerate(self.tables_lengths_cumsum) if idx < t)
+        )
+        row_idx = idx - (table_cumsum - self.tables_lengths[table_idx]) - 1
+        row = self.tables[table_idx][row_idx]
+        return tuple(row[k].tolist() for k in self.keys)
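+
+# Illustrative wiring (the 'hdfs' filter name is mapped to this class in
+# pipeline_pieces.build_filter_stack; the path below is hypothetical):
+#   datadef = DatasetDumpedDef("/data/train_dump", to_apply={"hdfs", "dense_pack"})
+# With an MPI communicator, each rank then iterates over its own shard of the
+# shuffled global index.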
diff --git a/refact_data_pipeline/filters_packing.py b/refact_data_pipeline/filters_packing.py
new file mode 100644
index 00000000..89edbfda
--- /dev/null
+++ b/refact_data_pipeline/filters_packing.py
@@ -0,0 +1,310 @@
+import random
+from typing import Any, Dict, List
+
+import binpacking
+import numpy as np
+import psutil
+from scipy.special import softmax
+
+from refact_data_pipeline import DatasetOpts
+
+ItemT = Dict[str, Any]
+
+
+class Packer:
+    """
+    Pack several tokenized records along the time axis.
+    Stat dict comes from the last inner record.
+    """
+
+    def __init__(self,
+                 inner_filter,
+                 dataopts: DatasetOpts,
+                 force16: bool = False,
+                 force_pack_complete: bool = False,
+                 force_pack1: bool = False,
+                 keys: List[str] = ["tokens", "mask", "first"]
+                 ):
+        self.inner_filter = inner_filter
+        self.enc = dataopts.encoding
+        self.pack_at_most: int = dataopts.get("pack_at_most", 6)
+        if force_pack1:
+            self.pack_at_most = 1
+        self.pack_complete: bool = dataopts.get("pack_complete", 0) == 1 or force_pack_complete
+        self.pack_pad0: bool = dataopts.get("pack_pad0", 1) == 1
+        self.n_ctx: int = dataopts.get("n_ctx", 2048)
+        self.force16 = force16
+        self.keys = keys
+
+    def __iter__(self):
+        accum = {k: list() for k in self.keys}
+        stats: Dict[str, int] = {
+            "packed_in": 0,
+            "packed_out": 0,
+            "packed_skip5tokens": 0,
+        }
+        last_rec_stats = dict()
+
+        def dict_to_emit():
+            nonlocal accum
+            stats["packed_out"] += 1
+            stats["pusher_resmem"] = psutil.Process().memory_info().rss / 1e9
+            last_rec_stats.update(stats)
+            accum_cut = {k: v[:self.n_ctx] for k, v in accum.items()}
+            emit = {
+                "stats": {**last_rec_stats, **stats},
+                **accum_cut,
+            }
+            if self.pack_pad0:
+                for k in self.keys:
+                    if k == "tokens":
+                        emit[k].extend([self.enc.DIAMOND] * (self.n_ctx - len(emit[k])))
+                    else:
+                        emit[k].extend([0] * (self.n_ctx - len(emit[k])))
+            accum = {k: accum[k][self.n_ctx:] for k in self.keys}
+            return emit
+
+        packed_n = 0
+        for rec in self.inner_filter:
+            if sum(rec["mask"]) < 5:
+                stats["packed_skip5tokens"] += 1
+                continue
+            last_rec_stats = rec["stats"]
+            stats["packed_in"] += 1
+            existing_len = len(accum[self.keys[0]])
+            if self.pack_complete:
+                predict_len = existing_len + len(rec["tokens"])
+                if existing_len > 0 and (
+                        predict_len >= self.n_ctx or packed_n >= self.pack_at_most
+                ):
+                    yield dict_to_emit()
+                    for a in accum.values():
+                        a.clear()
+                    packed_n = 0
+            for k in self.keys:
+                accum[k].extend(rec[k])
+            while self.force16 and len(accum[self.keys[0]]) & 15:
+                padlen = 16 - (len(accum[self.keys[0]]) & 15)
+                for k in self.keys:
+                    if k == "tokens":
+                        accum[k].extend([self.enc.DIAMOND] * padlen)
+                    else:
+                        accum[k].extend([0] * padlen)
+            packed_n += 1
+            if not self.pack_complete:
+                while len(accum[self.keys[0]]) >= self.n_ctx:
+                    yield dict_to_emit()
+                    packed_n = 1
+            len0 = len(accum[self.keys[0]])
+            assert all(len0 == len(accum[k]) for k in self.keys[1:])
+        if len(accum[self.keys[0]]):
+            yield dict_to_emit()
+
+
+class SinglePacker:
+    """
+    Pack a single tokenized record along the time axis,
+    padding or truncating it to n_ctx.
+    """
+
+    def __init__(
+            self,
+            inner_filter,
+            dataopts: DatasetOpts,
+            keys: List[str] = ["tokens", "first"]
+    ):
+        self.inner_filter = inner_filter
+        self.enc = dataopts.encoding
+        self.n_ctx: int = dataopts.get("n_ctx", 2048)
+        self.keys = keys
+
+    def __iter__(self):
+        for rec in self.inner_filter:
+            output = dict(stats=rec["stats"])
+            for k in self.keys:
+                if len(rec[k]) < self.n_ctx:
+                    rec[k] += [self.enc.DIAMOND] * (self.n_ctx - len(rec[k]))
+                output[k] = rec[k][:self.n_ctx]
+            output["mask"] = [t != self.enc.DIAMOND for t in output['tokens']]
+            yield output
+
+
+class DensePacker:
+    """
+    Pack several tokenized records along the time axis.
+    Stat dict comes from the last inner record.
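+
+    Roughly: incoming records are buffered, then for each emitted item a subset of
+    buffered records whose lengths best fill the n_ctx budget is picked (via
+    binpacking), merged in random order, and padded with DIAMOND tokens up to n_ctx.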
+    """
+
+    def __init__(
+            self,
+            inner_filter,
+            dataopts: DatasetOpts,
+    ):
+        self.inner_filter_iter = iter(inner_filter)
+        self.enc = dataopts.encoding
+        self.n_ctx: int = dataopts['n_ctx']
+        self.pack_single: bool = dataopts.get('pack_single', 0) == 1
+        self.pack_complete: bool = dataopts.get('pack_complete', 1) == 1
+        self.drop_less_than_t: int = dataopts.get('pack_drop_less_than_t', 6)
+        self.buffer_size: int = dataopts.get('pack_buffer_size', 256)
+        self.keys = dataopts.get('packer_keys', 'tokens;mask;first').split(';')
+        self.max_packing_rounds = 8
+        self.do_nothing_keys = ['stats']
+        assert len(self.keys) > 0
+        self.buffer = []
+        self.stats = dict(
+            packed_in=0,
+            packed_out=0,
+            packed_small_dropped=0,
+            last_paddings_perc=0.0
+        )
+
+    def __make_padded_item(self, length: int) -> ItemT:
+        padded_item = dict()
+        for k in self.keys:
+            if k == 'tokens':
+                padded_item[k] = [self.enc.DIAMOND for _ in range(length)]
+            elif k in {'mask', 'first'}:
+                padded_item[k] = [0 for _ in range(length)]
+            else:
+                assert False, f'Unknown key={k} to process'
+        return padded_item
+
+    def __item_len(self, item: ItemT) -> int:
+        return len(item[self.keys[0]])
+
+    def __items_len(self, items):
+        return sum(self.__item_len(i) for i in items)
+
+    def __fill_buffer(self):
+        # check the capacity before pulling the next record, so that a record
+        # fetched into an already full buffer is not silently dropped
+        while len(self.buffer) < self.buffer_size:
+            item = next(self.inner_filter_iter, None)
+            if item is None:
+                break
+            if self.__item_len(item) <= self.drop_less_than_t:
+                self.stats['packed_small_dropped'] += 1
+                continue
+            self.buffer.append(item)
+
+    def __add_to_acc(
+            self,
+            items_acc: List[ItemT],
+            items_to_add: List[ItemT]
+    ) -> List[ItemT]:
+        left_overs = []
+        for item in items_to_add:
+            item_to_add, left_over_item = dict(), dict()
+            length_to_add = self.n_ctx - self.__items_len(items_acc)
+            for key in self.keys:
+                item_to_add[key] = item[key][:length_to_add]
+                left_over_item[key] = item[key][length_to_add:]
+            for key in self.do_nothing_keys:
+                if key not in item:
+                    continue
+                item_to_add[key] = item[key]
+                left_over_item[key] = item[key]
+            items_acc.append(item_to_add)
+            if not self.pack_complete and self.__item_len(left_over_item) > self.drop_less_than_t:
+                left_overs.append(left_over_item)
+            elif not self.pack_complete and self.__item_len(left_over_item) <= self.drop_less_than_t:
+                self.stats['packed_small_dropped'] += 1
+        return left_overs
+
+    def __find_best_for_budget(self, budget: int, force_random_get: bool = False) -> List[ItemT]:
+        def _pop_item_by_length(length: int) -> ItemT:
+            idx = next((idx for idx, item in enumerate(self.buffer)
+                        if self.__item_len(item) == length), None)
+            assert idx is not None, f'No item with length={length}'
+            return self.buffer.pop(idx)
+
+        assert len(self.buffer) > 0
+        if budget == 0:
+            return []
+
+        if force_random_get or not self.pack_complete:
+            item = self.buffer.pop(random.randint(0, len(self.buffer) - 1))
+            return [item]
+        else:
+            lengths = [self.__item_len(i) for i in self.buffer]
+            lengths = [l for l in lengths if l <= budget]
+            if len(lengths) == 0:
+                return []
+            # we can up-weight `old` items later
+            bins = binpacking.to_constant_volume(lengths, budget)
+            if len(bins) == 0:
+                return []
+
+            # prioritize bins that fill more of the budget
+            p = softmax(np.exp(np.array([sum(b) for b in bins]) / budget * 2))
+            bin = bins[np.random.choice(list(range(len(bins))), p=p)]
+            items = [_pop_item_by_length(l) for l in bin]
+            return items
+
+    def __merge_items(
+            self,
+            items_acc: List[ItemT],
+            random_order: bool
+    ) -> ItemT:
+        assert len(items_acc) > 0
+
+        if random_order:
+            np.random.shuffle(items_acc)
+        last_item = items_acc[-1]
+        if self.__items_len(items_acc) < self.n_ctx:
+            items_acc.append(self.__make_padded_item(self.n_ctx - self.__items_len(items_acc)))
+
+        output_item = dict([(k, []) for k in self.keys])
+        # taking the last item for other useful keys
+        output_item.update(dict([(k, last_item[k]) for k in self.do_nothing_keys if k in last_item]))
+        if 'stats' in output_item:
+            output_item['stats'].update(self.stats)
+        else:
+            output_item['stats'] = self.stats
+
+        for item in items_acc:
+            for k in self.keys:
+                output_item[k].extend(item[k])
+
+        return output_item
+
+    def __iter__(self):
+        def _pack_iteration(acc, force_random_get=False):
+            items = self.__find_best_for_budget(
+                budget=self.n_ctx - self.__items_len(acc),
+                force_random_get=force_random_get
+            )
+            if len(items) > 0:
+                self.stats['packed_in'] += len(items)
+                leftovers = self.__add_to_acc(acc, items)
+                self.buffer.extend(leftovers)
+            return len(items)
+
+        def _merge_acc(acc):
+            assert len(acc) > 0
+            self.stats['packed_out'] += 1
+            output_item = self.__merge_items(acc, random_order=True)
+            if 'tokens' in output_item:
+                self.stats['last_paddings_perc'] = \
+                    (np.array(output_item['tokens']) == self.enc.DIAMOND).sum() / self.n_ctx
+            return output_item
+
+        while True:
+            self.__fill_buffer()
+            if len(self.buffer) == 0:
+                # a bare `raise StopIteration()` inside a generator turns into a
+                # RuntimeError (PEP 479), so finish the generator with a plain return
+                return
+
+            items_acc = []
+            _pack_iteration(acc=items_acc, force_random_get=True)
+            if self.pack_single:
+                yield _merge_acc(acc=items_acc)
+                continue
+
+            for _ in range(self.max_packing_rounds):
+                packed = _pack_iteration(acc=items_acc)
+                if packed == 0:
+                    break
+
+            yield _merge_acc(acc=items_acc)
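+
+# Emitted items look roughly like this (illustrative; the keys follow `packer_keys`,
+# and `stats` also carries the stats of the last inner record):
+#   {"tokens": [... n_ctx ints ...], "mask": [...], "first": [...],
+#    "stats": {"packed_in": ..., "packed_out": ..., "packed_small_dropped": ...,
+#              "last_paddings_perc": ...}}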
diff --git a/refact_data_pipeline/pipeline_pieces.py b/refact_data_pipeline/pipeline_pieces.py
index 03042619..d677c1df 100644
--- a/refact_data_pipeline/pipeline_pieces.py
+++ b/refact_data_pipeline/pipeline_pieces.py
@@ -8,10 +8,14 @@
 import datetime
 import traceback
 
+from mpi4py import MPI
+
 from refact_encoding import RefactEncoding
 from refact_data_pipeline.datadef import DatasetOpts
-from refact_data_pipeline.datadef import DatasetDef
+from refact_data_pipeline.datadef import DatasetDef, DatasetDumpedDef
 from refact_data_pipeline.datadef import DatasetMix
+from refact_data_pipeline.filters_hdfs import Hdf5Dataset
+from refact_data_pipeline.filters_packing import Packer, SinglePacker, DensePacker
 
 from typing import Dict, List, Union, Iterable, Any
 
@@ -219,121 +223,6 @@ def __iter__(self):
         }
 
 
-class Packer:
-    """
-    Pack several tokenized records along time axis.
-    Stat dict comes from last inner record.
-    """
-    def __init__(self,
-            inner_filter,
-            dataopts: DatasetOpts,
-            force16: bool=False,
-            force_pack_complete: bool=False,
-            force_pack1: bool=False,
-            keys: List[str] = ["tokens", "mask", "first"]
-        ):
-        self.inner_filter = inner_filter
-        self.enc = dataopts.encoding
-        self.pack_at_most: int = dataopts.get("pack_at_most", 6)
-        if force_pack1:
-            self.pack_at_most = 1
-        self.pack_complete: int = dataopts.get("pack_complete", 0) == 1 or force_pack_complete
-        self.pack_pad0: int = dataopts.get("pack_pad0", 1) == 1
-        self.n_ctx: int = dataopts.get("n_ctx", 2048)
-        self.force16 = force16
-        self.keys = keys
-
-    def __iter__(self):
-        accum = {k: list() for k in self.keys}
-        stats: Dict[str, int] = {
-            "packed_in": 0,
-            "packed_out": 0,
-            "packed_skip5tokens": 0,
-        }
-        last_rec_stats = dict()
-        def dict_to_emit():
-            nonlocal accum
-            stats["packed_out"] += 1
-            stats["pusher_resmem"] = psutil.Process().memory_info().rss / 1e9
-            last_rec_stats.update(stats)
-            accum_cut = {k: v[:self.n_ctx] for k, v in accum.items()}
-            emit = {
-                "stats": {**last_rec_stats, **stats},
-                **accum_cut,
-            }
-            if self.pack_pad0:
-                for k in self.keys:
-                    if k=="tokens":
-                        emit[k].extend([self.enc.DIAMOND]*(self.n_ctx - len(emit[k])))
-                    else:
-                        emit[k].extend([0]*(self.n_ctx - len(emit[k])))
-            accum = {k: accum[k][self.n_ctx:] for k in self.keys}
-            return emit
-        packed_n = 0
-        for rec in self.inner_filter:
-            if sum(rec["mask"]) < 5:
-                stats["packed_skip5tokens"] += 1
-                continue
-            last_rec_stats = rec["stats"]
-            stats["packed_in"] += 1
-            existing_len = len(accum[self.keys[0]])
-            if self.pack_complete:
-                predict_len = existing_len + len(rec["tokens"])
-                if existing_len > 0 and (
-                    predict_len >= self.n_ctx or packed_n >= self.pack_at_most
-                ):
-                    yield dict_to_emit()
-                    for a in accum.values():
-                        a.clear()
-                    packed_n = 0
-            for k in self.keys:
-                accum[k].extend(rec[k])
-            while self.force16 and len(accum[self.keys[0]]) & 15:
-                padlen = 16 - (len(accum[self.keys[0]]) & 15)
-                for k in self.keys:
-                    if k=="tokens":
-                        accum[k].extend([self.enc.DIAMOND]*padlen)
-                    else:
-                        accum[k].extend([0]*padlen)
-            packed_n += 1
-            if not self.pack_complete:
-                while len(accum[self.keys[0]]) >= self.n_ctx:
-                    yield dict_to_emit()
-                    packed_n = 1
-            len0 = len(accum[self.keys[0]])
-            assert all(len0==len(accum[k]) for k in self.keys[1:])
-        if len(accum[self.keys[0]]):
-            yield dict_to_emit()
-
-
-class SinglePacker:
-    """
-    Pack several tokenized records along time axis.
-    Stat dict comes from last inner record.
-    """
-
-    def __init__(
-            self,
-            inner_filter,
-            dataopts: DatasetOpts,
-            keys: List[str] = ["tokens", "first"]
-    ):
-        self.inner_filter = inner_filter
-        self.enc = dataopts.encoding
-        self.n_ctx: int = dataopts.get("n_ctx", 2048)
-        self.keys = keys
-
-    def __iter__(self):
-        for rec in self.inner_filter:
-            output = dict(stats=rec["stats"])
-            for k in self.keys:
-                if len(rec[k]) < self.n_ctx:
-                    rec[k] += [self.enc.DIAMOND] * (self.n_ctx - len(rec[k]))
-                output[k] = rec[k][:self.n_ctx]
-            output["mask"] = [t != self.enc.DIAMOND for t in output['tokens']]
-            yield output
-
-
 class Shuffle:
     def __init__(self,
                  inner_filter,
@@ -382,10 +271,10 @@ def build_filter_stack(
         datadef: Union[DatasetDef, DatasetMix],
         dataopts: DatasetOpts,
         enc: RefactEncoding,
-        comm: Any,
+        comm: MPI.Comm,
         cold_restart: List[int] = [],
         cold_restart_offset = 0,
-        skip_assert_flag: bool = False,
+        skip_assert_flag: bool = False
 ):
     dataopts.set_encoding(enc)
     if isinstance(datadef, DatasetMix):
@@ -401,12 +290,17 @@ def build_filter_stack(
         cold_restart = [0]*comm.size
     path = datadef.cloud_path
     files_len = len(datadef.cloud_files)
-    if files_len == 1:
-        my_files = datadef.cloud_files
-    elif files_len % comm.size == 0:
-        my_files = [fn for i, fn in enumerate(datadef.cloud_files) if i % comm.size == comm.rank]
+
+    if not isinstance(datadef, DatasetDumpedDef):
+        if files_len == 1:
+            my_files = datadef.cloud_files
+        elif files_len % comm.size == 0:
+            my_files = [fn for i, fn in enumerate(datadef.cloud_files) if i % comm.size == comm.rank]
+        else:
+            assert 0, "datadef.cloud_files has %i files, but comm.size is %i" % (files_len, comm.size)
     else:
-        assert 0, "datadef.cloud_files has %i files, but comm.size is %i" % (files_len, comm.size)
+        my_files = datadef.cloud_files
+
     log("dataset '%s' has %i files" % (path, len(my_files)))
     assert len(my_files) > 0
     ds = None
@@ -416,6 +310,10 @@
                 cold_restart_key=cold_restart_offset + comm.rank,
                 cold_restart_skip=cold_restart[cold_restart_offset + comm.rank],
             )
+        elif ds is None and filt == 'hdfs':
+            ds = Hdf5Dataset(dataopts, my_files, comm=comm,
+                             cold_restart_skip=cold_restart[cold_restart_offset + comm.rank],
+                             )
         elif filt == "splitranks":
             ds = SplitRanks(ds, dataopts, commrank=comm.rank, commsize=comm.size)
         elif ds and filt == "tokenize":
             ds = ...
@@ -426,6 +324,8 @@
             ds = Packer(ds, dataopts)
         elif ds and filt == "single_pack":
             ds = SinglePacker(ds, dataopts)
+        elif ds and filt == "dense_pack":
+            ds = DensePacker(ds, dataopts)
         elif ds and filt == "pack16":
             ds = Packer(ds, dataopts, force16=True)
         elif ds and filt == "shuffle":