From d11f89c624c8ac5255385bb78381e0679a37f738 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 15 Nov 2024 10:51:23 -0800 Subject: [PATCH 01/56] ignore __dj__produced_data__ --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6ea36c585..6ba50fada 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ dist wandb/ __pycache__ .vscode/ +**/__dj__produced_data__/* From 41dea264fbd6f62b2c948d61581326922c6a0e2d Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 18 Nov 2024 22:27:37 -0800 Subject: [PATCH 02/56] add download framework; add wiki support --- .gitignore | 1 + data_juicer/download/__inif__.py | 0 data_juicer/download/arxiv.py | 405 ++++++++++++++ data_juicer/download/base.py | 223 ++++++++ data_juicer/download/commoncrawl.py | 0 data_juicer/download/wikipedia.py | 793 ++++++++++++++++++++++++++++ data_juicer/utils/file_utils.py | 169 +++++- environments/minimal_requires.txt | 1 + tests/download/__init__.py | 0 tests/download/test_download.py | 23 + 10 files changed, 1614 insertions(+), 1 deletion(-) create mode 100644 data_juicer/download/__inif__.py create mode 100644 data_juicer/download/arxiv.py create mode 100644 data_juicer/download/base.py create mode 100644 data_juicer/download/commoncrawl.py create mode 100644 data_juicer/download/wikipedia.py create mode 100644 tests/download/__init__.py create mode 100644 tests/download/test_download.py diff --git a/.gitignore b/.gitignore index 6ba50fada..48108c132 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ wandb/ __pycache__ .vscode/ **/__dj__produced_data__/* +venv/ diff --git a/data_juicer/download/__inif__.py b/data_juicer/download/__inif__.py new file mode 100644 index 000000000..e69de29bb diff --git a/data_juicer/download/arxiv.py b/data_juicer/download/arxiv.py new file mode 100644 index 000000000..af77427e1 --- /dev/null +++ b/data_juicer/download/arxiv.py @@ -0,0 +1,405 @@ +import gzip +import os +import re +import subprocess +import tarfile +import tempfile + +from base import (DocumentDownloader, DocumentExtractor, DocumentIterator, + download_and_extract, get_arxiv_urls) + +from data_juicer.core.data import DJDataset +from data_juicer.utils.file_utils import (expand_outdir_and_mkdir, + get_all_files_paths_under) + +# The iterator and extractor code are in large part taken +# from the Red-Pajama repo +# https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv + + +class ArxivDownloader(DocumentDownloader): + + def __init__(self, download_dir, verbose=False): + super().__init__() + self._download_dir = download_dir + self._verbose = False + + def download(self, tarfile): + output_file = os.path.join(self._download_dir, tarfile) + s3path = os.path.join('s3://arxiv/src', tarfile) + if os.path.exists(output_file): + print(f'tar file: {output_file} exists. Not downloading') + else: + print(f'Downloading {s3path} and writing to {output_file}') + cmd = [ + 's5cmd', '--request-payer=requester', 'cp', s3path, output_file + ] + if self._verbose: + stdout, stderr = None, None + else: + stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL + p = subprocess.run( + cmd, + stdout=stdout, + stderr=stderr, + ) + if p.returncode != 0: + print(f'Failed to download {s3path} to {output_file}') + + return output_file + + +class ArxivIterator(DocumentIterator): + + def __init__(self, log_frequency=1000): + super().__init__() + self._log_freq = log_frequency + self._cnt = 0 + + def iterate(self, file_path): + self._cnt = 0 + download_dir = os.path.split(file_path)[0] + bname = os.path.split(file_path)[-1] + # with (tempfile...), yapf spits lib2to3.pgen2.parse.ParseError + with tempfile.TemporaryDirectory(dir=download_dir) as tmpdir: + with tarfile.open(file_path) as tf: + tf.extractall(members=tf.getmembers(), path=tmpdir) + for i, item in enumerate(get_all_files_paths_under(tmpdir)): + if self._cnt > 0 and self._cnt % self._log_freq == 0: + print(f'Extracted {self._cnt} papers from {file_path}') + self._cnt += 1 + + tex_files = self._tex_proj_loader(item) + arxiv_id = os.path.splitext(os.path.split(item)[-1])[0] + + # get the arxiv id in the correct format + try: + clean_arxiv_id = self._format_arxiv_id(arxiv_id) + except Exception as e: + print( + f'[WARNING] failed to format arxiv id {arxiv_id}; exception={e}' # noqa: E501 + ) + clean_arxiv_id = arxiv_id + + if tex_files is None: + continue + + yield { + 'id': clean_arxiv_id, + 'source_id': f'{bname}' + }, tex_files + + def _tex_proj_loader(self, file_or_dir_path): + r"""function to load the tex files from a tar file or a gzip file. The + function will return a tuple containing a list of tex files and the + timestamp of the project. + + @param file_or_dir_path: path to the tar file or the gzip file + + @return: tuple containing a list of tex files and the timestamp of the + project + """ # noqa E501 + files_and_content = [] + + try: + # if it is a directory, open it as a tarfile + with tarfile.open(file_or_dir_path) as sub_tf: + for member in sub_tf.getmembers(): + if member.name.endswith('.tex'): + + file_content = sub_tf.extractfile(member).read() + + try: + file_content = file_content.decode('utf-8') + except UnicodeDecodeError: + return None + + files_and_content.append(file_content) + + except tarfile.ReadError: + # otherwise we try opening it as a gzip file + try: + with gzip.open(file_or_dir_path, 'rb') as gz: + file_content = gz.read() + except Exception: + # all fails, we skip this file + # self._logger.info(f"[ERROR] {e}: {file_or_dir_path}") + return None + + try: + file_content = file_content.decode('utf-8') + except UnicodeDecodeError: + # self._logger.info(f"UnicodeDecodeError: {file_or_dir_path}") + return None + + files_and_content.append(file_content) + + except Exception as e: + print(f'[ERROR] {e}: {file_or_dir_path}') + return None + + return files_and_content + + def _format_arxiv_id(self, arxiv_id): + r"""this function brings the raw arxiv-id into a format compliant with the + specification from arxiv. This is used to create the url to the arxiv + abstract page. + + - Format prior to March 2007: + /YYMMNNN where N is a 3-digit number + - Format after March 2007: /YYMM.NNNNN where N is a + 5 (or 6)-digit number + + References: https://info.arxiv.org/help/arxiv_identifier.html + + @param arxiv_id: raw arxiv id which can be in one of the + following formats: + - + - + + @return: formatted arxiv id + """ # noqa: E501 + match = re.search(r'^([a-zA-Z-]*)([\d\.]+)$', arxiv_id) + + if match is None: + raise ValueError(f'Invalid arxiv id: {arxiv_id}') + + if match.group(1) == '': + return match.group(2) + + return f'{match.group(1)}/{match.group(2)}' + + +class ArxivExtractor(DocumentExtractor): + + def __init__(self): + super().__init__() + + def extract(self, content): + if len(content) == 0: + return None + + # build dictionaries that contain the definitions of all + # macros in all text files. This is later used to expand + # all macros used in the text with their definitions, so + # that consistency among different authors is ensured + + non_arg_macros = {} + for file_content in content: + non_arg_macros.update( + self._build_non_arg_macros_dict(file_content)) + + # TODO: macros that take arguments are not supported yet + arg_macros = {} + + # join multiple latex files with a newline character + try: + cleaned_latex_file_str = '\n'.join( + self._clean_tex_file( + file_content=file_content, + arg_macros=arg_macros, + non_arg_macros=non_arg_macros, + ) for file_content in content) + except Exception: + return {}, None + + # Don't return meta + if cleaned_latex_file_str is not None: + if len(cleaned_latex_file_str) > 0: + return {}, cleaned_latex_file_str + + def _clean_tex_file(self, file_content, arg_macros, non_arg_macros): + r"""function takes a tex file as input and returns a cleaned version. The + cleaned version is a concatenation of the tex files with the + following modifications: + + - remove all comments (i.e. all lines starting with %) + - remove everything before the first section-like header + - remove everything after the first occurrence of either \appendix or + \bibliography + - inline-expand definitions and macros + + @param file_content: the content of the tex file as a string. + + @return: cleaned tex file as a string + """ # noqa: E501 + # find the first occurence of a \section-like header and replace + # everything before it with an empty string. This matches the + # following pattern: \[optional-args]{name} + pattern = r'^(.*?)(' + pattern += r'\\\bchapter\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + pattern += r'\\\bpart\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + pattern += r'\\\bsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + pattern += r'\\\bsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + pattern += r'\\\bsubsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + pattern += r'\\\bparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}' + pattern += r'\\\bsubparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}' + pattern += r')' + + # if no section like header is found, then we return an empty string + if not re.search(pattern, file_content, flags=re.DOTALL): + return '' + + # replace everything with the second group of the match + # (i.e. everything after and including the section header) + file_content = re.sub( + pattern=pattern, + repl=r'\2', + string=file_content, + flags=re.DOTALL, # make sure that the dot matches also newlines + ) + + # remove all line comments + file_content = re.sub( + pattern=r'(?m)^%.*\n?', + repl=r'', + string=file_content, + flags=re.MULTILINE, + ) + + # remove all in comments within a line + file_content = re.sub( + # pattern matches a "%" that is not preceded by a backslash + pattern=r'[^\\]%.+$', + repl=r'', + string=file_content, + flags=re.MULTILINE, + ) + + # find the first occurence of either \appendix or \bibliography and + # replace everything after it with an empty string + pattern = r'(' + pattern += r'\\appendix|' + pattern += r'\\begin\{references\}|' + pattern += r'\\begin\{REFERENCES\}|' + pattern += r'\\begin\{thebibliography\}|' + pattern += r'\\bibliography\{.*\}' + pattern += r').*$' + + file_content = re.sub( + pattern=pattern, + repl=r'', + string=file_content, + flags=re.DOTALL, # make sure that the dot matches also newlines + ) + + # inline-expand all non-arg macros + for macro_name, macro_value in non_arg_macros.items(): + file_content = re.sub( + # make pattern grouped to make sure that the macro is not part + # of a longer alphanumeric word + pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', + # replace the macro with its value and add back the character + # that was matched after the macro + repl=macro_value + r'\2', + string=file_content, + ) + + # inline-expand all macros that use args + # TODO: inline-expand macros with args + for macro_name, macro_value in arg_macros.items(): + pass + + return file_content + + def _build_non_arg_macros_dict(self, file_content): + r"""function takes the content of a tex file and returns a dictionary + that contains the definitions of all macros that do not use arguments. + The dictionary is of the form {macro_name: macro_value}. + + @param file_content: the content of the tex file as a string. + + @return: dict + """ + # regex for extracting \newcommand macros without arguments + non_arg_nc_reg = re.compile( + # this regex matches the following: + # \newcommand{\macro_name}{macro_value} + # \newcommand*{\macro_name}{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. + pattern=r'\\\bnewcommand\b\*?\{(\\[a-zA-Z0-9]+?)\}\{(.*?)\}$', + flags=re.MULTILINE, + ) + + # regex for extracting \def macros without arguments + non_arg_def_reg = re.compile( + # this regex matches the following: + # \def\macro_name{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. + pattern=r'\\def\s*(\\[a-zA-Z0-9]+?)\s*\{(.*?)\}$', + flags=re.MULTILINE, + ) + + # Extract all user-defined LaTeX macros from the preamble + macros = {} + for reg in [non_arg_nc_reg, non_arg_def_reg]: + for match in reg.finditer(file_content): + # convert the macro name and value to a raw string that can be + # used in re.sub + macro_name = match.group(1).encode('unicode-escape').decode( + 'utf-8') + macro_val = match.group(2).encode('unicode-escape').decode( + 'utf-8') + + macros[macro_name] = macro_val + + return macros + + +def download_arxiv( + output_path: str, + output_type: str = 'jsonl', + raw_download_dir=None, + keep_raw_download=False, + force_download=False, + url_limit=None, +) -> DJDataset: + """ + Downloads Arxiv tar files and extracts them + + Args: + output_path: The path to the root directory of the files + output_type: The file type to save the data as. + raw_download_dir: Path to store the raw download files for intermediate processing. + If None, they are stored in a folder named "downloads" under output_path. + keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + url_limit: The maximum number of raw files to download from the snapshot. If None, all + files from the range of snapshots are downloaded. + """ # noqa: E501 + arxiv_urls = get_arxiv_urls() + if url_limit: + arxiv_urls = arxiv_urls[:url_limit] + output_paths = list( + map(lambda url: os.path.join(output_path, f'{url}.{output_type}'), + arxiv_urls)) + + if not raw_download_dir: + raw_download_dir = os.path.join(output_path, 'downloads') + expand_outdir_and_mkdir(raw_download_dir) + downloader = ArxivDownloader(raw_download_dir) + iterator = ArxivIterator() + extractor = ArxivExtractor() + + output_format = { + 'text': str, + 'id': str, + 'source_id': str, + 'filename': str, + } + dataset = download_and_extract( + arxiv_urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + ) + + return dataset diff --git a/data_juicer/download/base.py b/data_juicer/download/base.py new file mode 100644 index 000000000..b15c67bc2 --- /dev/null +++ b/data_juicer/download/base.py @@ -0,0 +1,223 @@ +import json +import os +import subprocess +from abc import ABC, abstractmethod +from functools import partial +from typing import List, Optional, Tuple, Union +from urllib.parse import urljoin + +import pandas as pd +import requests +from bs4 import BeautifulSoup +from datasets import Dataset + +from data_juicer.utils.file_utils import (read_single_partition, + single_partition_write_with_filename) + + +class DocumentDownloader(ABC): + """Abstract class for downloading remote data to disk""" + + def __init__(self): + super().__init__() + + @abstractmethod + def download(self, url): + pass + + +class DocumentIterator(ABC): + """ + Abstract iterator class for reading in raw records that have been + downloaded to disk + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def iterate(self, file_path): + pass + + +class DocumentExtractor(ABC): + """Abstract class for extracting text from records read from disk""" + + def __init__(self): + super().__init__() + + @abstractmethod + def extract(self, content): + pass + + +def _download_and_extract_single_partition(paths: List[Tuple[str, str]], + downloader: DocumentDownloader, + iterator: DocumentIterator, + extractor: DocumentExtractor, + output_type: str, + keep_raw_download: bool, + force_download: bool, + input_meta: Union[str, dict] = None, + meta: Union[str, dict] = None, + item_limit=None) -> pd.DataFrame: + url, output_path = paths + + if os.path.exists(output_path) and not force_download: + partition = read_single_partition([output_path], + filetype=output_type, + add_filename=True) + return partition + + downloaded_file = downloader.download(url) + records = [] + # Iterate over all records in file + item_count = 0 + for item in iterator.iterate(downloaded_file): + item_count += 1 + if item_limit and item_count >= item_limit: + break + record_meta, content = item + # Extract the text from the record + extracted = extractor.extract(content) + if extracted is not None: + text_meta, text = extracted + if text is not None: + line = { + 'text': text, + **text_meta, + **record_meta, + } + records.append(line) + + partition = pd.DataFrame(records) + filename = os.path.basename(output_path) + output_dir = os.path.dirname(output_path) + partition['filename'] = filename + single_partition_write_with_filename(partition, + output_dir, + output_type=output_type) + if not keep_raw_download: + os.remove(downloaded_file) + + return partition + + +def download_and_extract(urls: List[str], + output_paths: List[str], + downloader: DocumentDownloader, + iterator: DocumentIterator, + extractor: DocumentExtractor, + output_format: dict, + output_type: str = 'jsonl', + keep_raw_download=False, + force_download=False, + input_meta: Union[str, dict] = None, + item_limit=None) -> Dataset: + """ + Downloads and extracts a dataset + + Args: + urls: A list of urls to download the dataset from + output_paths: A list of paths to save the final extracted output to. + The raw output of the downloader will be saved using the path given by downloader.download(url). + downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage + iterator: A DocumentIterator that handles iterating through the downloaded file's format + extractor: A DocumentExtractor that handles extracting the data from its raw format into text + output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction. + output_type: The file type to save the dataset as. + keep_raw_download: Whether to keep the pre-extracted download file. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + input_meta: A dictionary or a string formatted as a dictionary, which outlines + the field names and their respective data types within the JSONL input file. + + Returns: + A HuggingFace DataSet of the downloaded data + """ # noqa: E501 + if len(urls) == 0: + raise ValueError('No urls were provided to download') + + if len(urls) != len(output_paths): + raise ValueError( + f'Different number of urls and output_paths. ' + f'{len(urls)} urls vs {len(output_paths)} output_paths') + + output_format = dict(sorted(output_format.items())) + part = partial(_download_and_extract_single_partition, + downloader=downloader, + iterator=iterator, + extractor=extractor, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + input_meta=input_meta, + meta=output_format, + item_limit=item_limit) + combined_df = pd.concat(map(part, zip(urls, + output_paths))) # list of DataFrames + return Dataset.from_pandas(combined_df) + + +def get_wikipedia_urls( + language='en', + wikidumps_index_prefix='https://dumps.wikimedia.org', + dump_date: Optional[str] = None, +) -> List[str]: + """ + Retrieves all urls pointing to the latest Wikipedia dumps + + Args: + language: Desired language of the Wikipedia dump. + wikidumps_index_prefix: The base url for all wikipedia dumps + dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. + If None, latest dump is used. + """ # noqa: E501 + wiki_index_url = urljoin(wikidumps_index_prefix, f'{language}wiki') + if not dump_date: + # First get the index + raw_wiki_index = requests.get(wiki_index_url) + wiki_index = raw_wiki_index.content.decode('utf-8') + wiki_index_parsed = BeautifulSoup(wiki_index, 'lxml') + + # Get all dumps available in the index + dumps = wiki_index_parsed.find_all('a') + dump_date = dumps[-2].text + else: + # A trailing / is needed for the url + dump_date = dump_date + '/' + + # Get the json dump data + wiki_latest_dump = urljoin(wiki_index_url + '/', dump_date) + wiki_latest_dump_status = urljoin(wiki_latest_dump, 'dumpstatus.json') + raw_dump_data = requests.get(wiki_latest_dump_status) + try: + dump_data = json.loads(raw_dump_data.content) + except json.decoder.JSONDecodeError: + raise ValueError(f'No wikipedia dump found for {dump_date[:-1]}') + + # Get all multistream files within the dump data + wikipedia_urls = [] + for ifile in dump_data['jobs']['articlesmultistreamdump']['files']: + if 'xml' in ifile: + url = urljoin(wiki_latest_dump, ifile) + wikipedia_urls.append(url) + + return wikipedia_urls + + +def get_arxiv_urls(): + command =\ + "s5cmd --request-payer=requester ls s3://arxiv/src/ | grep '.tar'" + result = subprocess.run(command, + capture_output=True, + text=True, + shell=True) + + if result.returncode != 0: + raise RuntimeError(f'Unable to get arxiv urls: {result.stderr}') + + urls = result.stdout.split()[3::4] + urls.sort() + + return urls diff --git a/data_juicer/download/commoncrawl.py b/data_juicer/download/commoncrawl.py new file mode 100644 index 000000000..e69de29bb diff --git a/data_juicer/download/wikipedia.py b/data_juicer/download/wikipedia.py new file mode 100644 index 000000000..ae8683b88 --- /dev/null +++ b/data_juicer/download/wikipedia.py @@ -0,0 +1,793 @@ +import bz2 +import codecs +import os +import re +import subprocess +import xml.etree.cElementTree as etree +from urllib.parse import quote, urlparse + +import mwparserfromhell +from datasets import Dataset + +from data_juicer.utils.file_utils import expand_outdir_and_mkdir + +from .base import (DocumentDownloader, DocumentExtractor, DocumentIterator, + download_and_extract, get_wikipedia_urls) + +# The majority of this code is taken from the HuggingFace +# implementation of the Wikipedia dataset preparation: +# https://github.com/huggingface/datasets/blob/7e30308f49f8c85dc7a2ab5aafbff04b5d2f38e2/datasets/wikipedia/wikipedia.py + +MEDIA_ALIASES = { + 'ab': ['Медиа', 'Файл', 'Афаил', 'Амедиа', 'Изображение'], + 'ace': ['Beureukaih', 'Gambar', 'Alat', 'Berkas'], + 'ady': ['Медиа'], + 'af': ['Lêer', 'Beeld'], + 'als': ['Medium', 'Datei', 'Bild'], + 'am': ['ፋይል', 'ስዕል'], + 'an': ['Imachen', 'Imagen'], + 'ang': ['Ymele', 'Biliþ'], + 'ar': ['ميديا', 'صورة', 'وسائط', 'ملف'], + 'arc': ['ܠܦܦܐ', 'ܡܝܕܝܐ'], + 'arz': ['ميديا', 'صورة', 'وسائط', 'ملف'], + 'as': ['চিত্ৰ', 'चित्र', 'চিত্র', 'মাধ্যম'], + 'ast': ['Imaxen', 'Ficheru', 'Imaxe', 'Archivu', 'Imagen', 'Medios'], + 'atj': ['Tipatcimoctakewin', 'Natisinahikaniwoc'], + 'av': ['Медиа', 'Файл', 'Изображение'], + 'ay': ['Medio', 'Archivo', 'Imagen'], + 'az': ['Mediya', 'Şəkil', 'Fayl'], + 'azb': ['رسانه', 'تصویر', 'مدیا', 'فایل', 'رسانه‌ای'], + 'ba': ['Медиа', 'Рәсем', 'Файл', 'Изображение'], + 'bar': ['Medium', 'Datei', 'Bild'], + 'bat-smg': ['Vaizdas', 'Medėjė', 'Abruozdielis'], + 'bcl': ['Medio', 'Ladawan'], + 'be': ['Мультымедыя', 'Файл', 'Выява'], + 'be-x-old': ['Мэдыя', 'Файл', 'Выява'], + 'bg': ['Медия', 'Файл', 'Картинка'], + 'bh': ['मीडिया', 'चित्र'], + 'bjn': ['Barakas', 'Gambar', 'Berkas'], + 'bm': ['Média', 'Fichier'], + 'bn': ['চিত্র', 'মিডিয়া'], + 'bpy': ['ছবি', 'মিডিয়া'], + 'br': ['Skeudenn', 'Restr'], + 'bs': ['Mediji', 'Slika', 'Datoteka', 'Medija'], + 'bug': ['Gambar', 'Berkas'], + 'bxr': ['Файл', 'Меди', 'Изображение'], + 'ca': ['Fitxer', 'Imatge'], + 'cbk-zam': ['Medio', 'Archivo', 'Imagen'], + 'cdo': ['文件', '媒體', '圖像', '檔案'], + 'ce': ['Хlум', 'Медиа', 'Сурт', 'Файл', 'Медйа', 'Изображение'], + 'ceb': ['Payl', 'Medya', 'Imahen'], + 'ch': ['Litratu'], + 'ckb': ['میدیا', 'پەڕگە'], + 'co': ['Immagine'], + 'crh': ['Медиа', 'Resim', 'Файл', 'Fayl', 'Ресим'], + 'cs': ['Soubor', 'Média', 'Obrázok'], + 'csb': ['Òbrôzk', 'Grafika'], + 'cu': ['Видъ', 'Ви́дъ', 'Дѣло', 'Срѣдьства'], + 'cv': ['Медиа', 'Ӳкерчĕк', 'Изображение'], + 'cy': ['Delwedd'], + 'da': ['Billede', 'Fil'], + 'de': ['Medium', 'Datei', 'Bild'], + 'din': ['Ciɛl', 'Apamduööt'], + 'diq': ['Medya', 'Dosya'], + 'dsb': ['Wobraz', 'Dataja', 'Bild', 'Medija'], + 'dty': ['चित्र', 'मिडिया'], + 'dv': ['ފައިލު', 'މީޑިއާ', 'ފައިލް'], + 'el': ['Εικόνα', 'Αρχείο', 'Μέσο', 'Μέσον'], + 'eml': ['Immagine'], + 'eo': ['Dosiero', 'Aŭdvidaĵo'], + 'es': ['Medio', 'Archivo', 'Imagen'], + 'et': ['Pilt', 'Fail', 'Meedia'], + 'eu': ['Irudi', 'Fitxategi'], + 'ext': ['Archivu', 'Imagen', 'Mediu'], + 'fa': ['رسانه', 'تصویر', 'مدیا', 'پرونده', 'رسانه‌ای'], + 'ff': ['Média', 'Fichier'], + 'fi': ['Kuva', 'Tiedosto'], + 'fiu-vro': ['Pilt', 'Meediä'], + 'fo': ['Miðil', 'Mynd'], + 'fr': ['Média', 'Fichier'], + 'frp': ['Émâge', 'Fichiér', 'Mèdia'], + 'frr': ['Medium', 'Datei', 'Bild'], + 'fur': ['Immagine', 'Figure'], + 'fy': ['Ofbyld'], + 'ga': ['Íomhá', 'Meán'], + 'gag': ['Mediya', 'Medya', 'Resim', 'Dosya', 'Dosye'], + 'gan': ['媒体文件', '文件', '文檔', '档案', '媒體', '图像', '圖像', '媒体', '檔案'], + 'gd': ['Faidhle', 'Meadhan'], + 'gl': ['Imaxe', 'Ficheiro', 'Arquivo', 'Imagem'], + 'glk': ['رسانه', 'تصویر', 'پرونده', 'فاىل', 'رسانه‌ای', 'مديا'], + 'gn': ['Medio', 'Imagen', "Ta'ãnga"], + 'gom': ['माध्यम', 'मिडिया', 'फायल'], + 'gor': ['Gambar', 'Berkas'], + 'got': ['𐍆𐌴𐌹𐌻𐌰'], + 'gu': ['દ્રશ્ય-શ્રાવ્ય (મિડિયા)', 'દ્રશ્ય-શ્રાવ્ય_(મિડિયા)', 'ચિત્ર'], + 'gv': ['Coadan', 'Meanyn'], + 'hak': ['文件', '媒體', '圖像', '檔案'], + 'haw': ['Kiʻi', 'Waihona', 'Pāpaho'], + 'he': ['תמונה', 'קו', 'מדיה', 'קובץ'], + 'hi': ['मीडिया', 'चित्र'], + 'hif': ['file', 'saadhan'], + 'hr': ['Mediji', 'DT', 'Slika', 'F', 'Datoteka'], + 'hsb': ['Wobraz', 'Dataja', 'Bild'], + 'ht': ['Imaj', 'Fichye', 'Medya'], + 'hu': ['Kép', 'Fájl', 'Média'], + 'hy': ['Պատկեր', 'Մեդիա'], + 'ia': ['Imagine', 'Multimedia'], + 'id': ['Gambar', 'Berkas'], + 'ig': ['Nká', 'Midia', 'Usòrò', 'Ákwúkwó orünotu', 'Ákwúkwó_orünotu'], + 'ii': ['媒体文件', '文件', '档案', '图像', '媒体'], + 'ilo': ['Midia', 'Papeles'], + 'inh': ['Медиа', 'Файл', 'Изображение'], + 'io': ['Imajo', 'Arkivo'], + 'is': ['Miðill', 'Mynd'], + 'it': ['Immagine'], + 'ja': ['メディア', 'ファイル', '画像'], + 'jbo': ['velsku', 'datnyvei'], + 'jv': ['Barkas', 'Medhia', 'Gambar', 'Médhia'], + 'ka': ['მედია', 'სურათი', 'ფაილი'], + 'kaa': ['Swret', 'Таспа', 'سۋرەت', 'Taspa', "Su'wret", 'Сурет', 'تاسپا'], + 'kab': ['Tugna'], + 'kbd': ['Медиа', 'Файл'], + 'kbp': ['Média', 'Fichier'], + 'kg': ['Fisye'], + 'kk': ['Swret', 'سۋرەت', 'Таспа', 'Taspa', 'Сурет', 'تاسپا'], + 'kl': ['Billede', 'Fiileq', 'Fil'], + 'km': ['ឯកសារ', 'រូបភាព', 'មេឌា', 'មីឌា'], + 'kn': ['ಚಿತ್ರ', 'ಮೀಡಿಯ'], + 'ko': ['미디어', '파일', '그림'], + 'koi': ['Медиа', 'Файл', 'Изображение'], + 'krc': ['Медиа', 'Файл', 'Изображение'], + 'ks': ['میڈیا', 'فَیِل'], + 'ksh': [ + 'Beld', 'Meedije', 'Medie', 'Belld', 'Medium', 'Datei', 'Meedijum', + 'Bild' + ], + 'ku': ['میدیا', 'پەڕگە', 'Medya', 'Wêne'], + 'kv': ['Медиа', 'Файл', 'Изображение'], + 'kw': ['Restren'], + 'ky': ['Медиа', 'Файл'], + 'la': ['Imago', 'Fasciculus'], + 'lad': ['Dossia', 'Medya', 'Archivo', 'Dosya', 'Imagen', 'Meddia'], + 'lb': ['Fichier', 'Bild'], + 'lbe': ['Медиа', 'Сурат', 'Изображение'], + 'lez': ['Медиа', 'Mediya', 'Файл', 'Şəkil', 'Изображение'], + 'lfn': ['Fix'], + 'li': ['Afbeelding', 'Plaetje', 'Aafbeilding'], + 'lij': ['Immaggine', 'Immagine'], + 'lmo': ['Immagine', 'Imàjine', 'Archivi'], + 'ln': ['Média', 'Fichier'], + 'lo': ['ສື່ອ', 'ສື່', 'ຮູບ'], + 'lrc': ['رسانه', 'تصویر', 'رسانه‌ای', 'جانیا', 'أسگ', 'ڤارئسگأر'], + 'lt': ['Vaizdas', 'Medija'], + 'ltg': ['Medeja', 'Fails'], + 'lv': ['Attēls'], + 'mai': ['मेडिया', 'फाइल'], + 'map-bms': ['Barkas', 'Medhia', 'Gambar', 'Médhia'], + 'mdf': ['Медиа', 'Няйф', 'Изображение'], + 'mg': ['Rakitra', 'Sary', 'Média'], + 'mhr': ['Медиа', 'Файл', 'Изображение'], + 'min': ['Gambar', 'Berkas'], + 'mk': ['Податотека', 'Медија', 'Медиум', 'Слика'], + 'ml': ['പ്രമാണം', 'ചി', 'മീഡിയ', 'പ്ര', 'ചിത്രം'], + 'mn': ['Медиа', 'Файл', 'Зураг'], + 'mr': ['चित्र', 'मिडिया'], + 'mrj': ['Медиа', 'Файл', 'Изображение'], + 'ms': ['Fail', 'Imej'], + 'mt': ['Midja', 'Medja', 'Stampa'], + 'mwl': ['Multimédia', 'Fexeiro', 'Ficheiro', 'Arquivo', 'Imagem'], + 'my': ['ဖိုင်', 'မီဒီယာ'], + 'myv': ['Медия', 'Артовкс', 'Изображение'], + 'mzn': ['رسانه', 'تصویر', 'مه‌دیا', 'مدیا', 'پرونده', 'رسانه‌ای'], + 'nah': ['Mēdiatl', 'Īxiptli', 'Imagen'], + 'nap': ['Fiùra', 'Immagine'], + 'nds': ['Datei', 'Bild'], + 'nds-nl': ['Ofbeelding', 'Afbeelding', 'Bestaand'], + 'ne': ['मीडिया', 'चित्र'], + 'new': ['किपा', 'माध्यम'], + 'nl': ['Bestand', 'Afbeelding'], + 'nn': ['Fil', 'Bilde', 'Filpeikar'], + 'no': ['Fil', 'Medium', 'Bilde'], + 'nov': [], + 'nrm': ['Média', 'Fichier'], + 'nso': ['Seswantšho'], + 'nv': ['Eʼelyaaígíí'], + 'oc': ['Imatge', 'Fichièr', 'Mèdia'], + 'olo': ['Kuva', 'Medii', 'Failu'], + 'or': ['ମାଧ୍ୟମ', 'ଫାଇଲ'], + 'os': ['Ныв', 'Медиа', 'Файл', 'Изображение'], + 'pa': ['ਤਸਵੀਰ', 'ਮੀਡੀਆ'], + 'pcd': ['Média', 'Fichier'], + 'pdc': ['Medium', 'Datei', 'Bild', 'Feil'], + 'pfl': ['Dadai', 'Medium', 'Datei', 'Bild'], + 'pi': ['मीडिया', 'पटिमा'], + 'pl': ['Plik', 'Grafika'], + 'pms': ['Figura', 'Immagine'], + 'pnb': ['میڈیا', 'تصویر', 'فائل'], + 'pnt': ['Εικόνα', 'Αρχείον', 'Εικόναν', 'Μέσον'], + 'ps': ['انځور', 'رسنۍ', 'دوتنه'], + 'pt': ['Multimédia', 'Ficheiro', 'Arquivo', 'Imagem'], + 'qu': ['Midya', 'Imagen', 'Rikcha'], + 'rm': ['Multimedia', 'Datoteca'], + 'rmy': ['Fişier', 'Mediya', 'Chitro', 'Imagine'], + 'ro': ['Fişier', 'Imagine', 'Fișier'], + 'roa-rup': ['Fişier', 'Imagine', 'Fișier'], + 'roa-tara': ['Immagine'], + 'ru': ['Медиа', 'Файл', 'Изображение'], + 'rue': ['Медіа', 'Медиа', 'Файл', 'Изображение', 'Зображення'], + 'rw': ['Dosiye', 'Itangazamakuru'], + 'sa': ['चित्रम्', 'माध्यमम्', 'सञ्चिका', 'माध्यम', 'चित्रं'], + 'sah': ['Миэдьийэ', 'Ойуу', 'Билэ', 'Изображение'], + 'sat': ['ᱨᱮᱫ', 'ᱢᱤᱰᱤᱭᱟ'], + 'sc': ['Immàgini'], + 'scn': ['Immagine', 'Mmàggini', 'Mèdia'], + 'sd': ['عڪس', 'ذريعات', 'فائل'], + 'se': ['Fiila'], + 'sg': ['Média', 'Fichier'], + 'sh': ['Mediji', 'Slika', 'Медија', 'Datoteka', 'Medija', 'Слика'], + 'si': ['රූපය', 'මාධ්‍යය', 'ගොනුව'], + 'sk': ['Súbor', 'Obrázok', 'Médiá'], + 'sl': ['Slika', 'Datoteka'], + 'sq': ['Figura', 'Skeda'], + 'sr': [ + 'Датотека', + 'Medij', + 'Slika', + 'Медија', + 'Datoteka', + 'Медиј', + 'Medija', + 'Слика', + ], + 'srn': ['Afbeelding', 'Gefre'], + 'stq': ['Bielde', 'Bild'], + 'su': ['Média', 'Gambar'], + 'sv': ['Fil', 'Bild'], + 'sw': ['Faili', 'Picha'], + 'szl': ['Plik', 'Grafika'], + 'ta': ['படிமம்', 'ஊடகம்'], + 'tcy': ['ಮಾದ್ಯಮೊ', 'ಫೈಲ್'], + 'te': ['ఫైలు', 'దస్త్రం', 'బొమ్మ', 'మీడియా'], + 'tet': ['Imajen', 'Arquivo', 'Imagem'], + 'tg': ['Акс', 'Медиа'], + 'th': ['ไฟล์', 'สื่อ', 'ภาพ'], + 'ti': ['ፋይል', 'ሜድያ'], + 'tk': ['Faýl'], + 'tl': ['Midya', 'Talaksan'], + 'tpi': ['Fail'], + 'tr': ['Medya', 'Resim', 'Dosya', 'Ortam'], + 'tt': ['Медиа', 'Рәсем', 'Файл', 'Räsem', 'Изображение'], + 'ty': ['Média', 'Fichier'], + 'tyv': ['Медиа', 'Файл', 'Изображение'], + 'udm': ['Медиа', 'Файл', 'Суред', 'Изображение'], + 'ug': ['ۋاسىتە', 'ھۆججەت'], + 'uk': ['Медіа', 'Медиа', 'Файл', 'Изображение', 'Зображення'], + 'ur': ['میڈیا', 'تصویر', 'وسیط', 'زریعہ', 'فائل', 'ملف'], + 'uz': ['Mediya', 'Tasvir', 'Fayl'], + 'vec': ['Immagine', 'Imàjine', 'Mèdia'], + 'vep': ['Pilt', 'Fail'], + 'vi': ['Phương_tiện', 'Tập_tin', 'Hình', 'Tập tin', 'Phương tiện'], + 'vls': ['Afbeelding', 'Ofbeeldienge'], + 'vo': ['Ragiv', 'Magod', 'Nünamakanäd'], + 'wa': ['Imådje'], + 'war': ['Medya', 'Fayl', 'Paypay'], + 'wo': ['Xibaarukaay', 'Dencukaay'], + 'wuu': ['文件', '档案', '图像', '媒体'], + 'xal': ['Аһар', 'Боомг', 'Изображение', 'Зург'], + 'xmf': ['მედია', 'სურათი', 'ფაილი'], + 'yi': ['מעדיע', 'תמונה', 'טעקע', 'בילד'], + 'yo': ['Fáìlì', 'Amóhùnmáwòrán', 'Àwòrán'], + 'za': ['媒体文件', '文件', '档案', '图像', '媒体'], + 'zea': ['Afbeelding', 'Plaetje'], + 'zh': ['媒体文件', 'F', '文件', '媒體', '档案', '图像', '圖像', '媒体', '檔案'], + 'zh-classical': ['文件', '媒體', '圖像', '檔案'], + 'zh-min-nan': ['tóng-àn', '文件', '媒體', 'Mûi-thé', '圖像', '檔案'], + 'zh-yue': [ + '檔', + '档', + '文件', + '图', + '媒體', + '圖', + '档案', + '图像', + '圖像', + '媒体', + '檔案', + ], +} + +CAT_ALIASES = { + 'ab': ['Категория', 'Акатегориа'], + 'ace': ['Kawan', 'Kategori'], + 'af': ['Kategorie'], + 'ak': ['Nkyekyem'], + 'als': ['Kategorie'], + 'am': ['መደብ'], + 'an': ['Categoría'], + 'ang': ['Flocc'], + 'ar': ['تصنيف'], + 'arc': ['ܣܕܪܐ'], + 'arz': ['تصنيف'], + 'as': ['CAT', 'শ্ৰেণী', 'श्रेणी', 'শ্রেণী'], + 'ast': ['Categoría'], + 'atj': ['Tipanictawin'], + 'av': ['Категория'], + 'ay': ['Categoría'], + 'az': ['Kateqoriya'], + 'azb': ['بؤلمه'], + 'ba': ['Төркөм', 'Категория'], + 'bar': ['Kategorie'], + 'bat-smg': ['Kategorija', 'Kateguorėjė'], + 'bcl': ['Kategorya'], + 'be': ['Катэгорыя'], + 'be-x-old': ['Катэгорыя'], + 'bg': ['Категория'], + 'bh': ['श्रेणी'], + 'bjn': ['Tumbung', 'Kategori'], + 'bm': ['Catégorie'], + 'bn': ['বিষয়শ্রেণী', 'വിഭാഗം'], + 'bpy': ['থাক'], + 'br': ['Rummad'], + 'bs': ['Kategorija'], + 'bug': ['Kategori'], + 'bxr': ['Категори', 'Категория'], + 'ca': ['Categoria'], + 'cbk-zam': ['Categoría'], + 'cdo': ['分類'], + 'ce': ['Категори', 'Тоба', 'Кадегар'], + 'ceb': ['Kategoriya'], + 'ch': ['Katigoria'], + 'ckb': ['پ', 'پۆل'], + 'co': ['Categoria'], + 'crh': ['Категория', 'Kategoriya'], + 'cs': ['Kategorie'], + 'csb': ['Kategòrëjô'], + 'cu': ['Катигорї', 'Категория', 'Катигорїꙗ'], + 'cv': ['Категори'], + 'cy': ['Categori'], + 'da': ['Kategori'], + 'de': ['Kategorie'], + 'din': ['Bekätakthook'], + 'diq': ['Kategoriye', 'Kategori'], + 'dsb': ['Kategorija'], + 'dty': ['श्रेणी'], + 'dv': ['ޤިސްމު'], + 'el': ['Κατηγορία'], + 'eml': ['Categoria'], + 'eo': ['Kategorio'], + 'es': ['CAT', 'Categoría'], + 'et': ['Kategooria'], + 'eu': ['Kategoria'], + 'ext': ['Categoría', 'Categoria'], + 'fa': ['رده'], + 'ff': ['Catégorie'], + 'fi': ['Luokka'], + 'fiu-vro': ['Katõgooria'], + 'fo': ['Bólkur'], + 'fr': ['Catégorie'], + 'frp': ['Catègorie'], + 'frr': ['Kategorie'], + 'fur': ['Categorie'], + 'fy': ['Kategory'], + 'ga': ['Rang', 'Catagóir'], + 'gag': ['Kategori', 'Kategoriya'], + 'gan': ['分類', '分类'], + 'gd': ['Roinn-seòrsa'], + 'gl': ['Categoría'], + 'glk': ['جرگه', 'رده'], + 'gn': ['Ñemohenda'], + 'gom': ['वर्ग', 'श्रेणी'], + 'gor': ['Dalala'], + 'got': ['𐌷𐌰𐌽𐍃𐌰'], + 'gu': ['શ્રેણી', 'CAT', 'શ્રે'], + 'gv': ['Ronney'], + 'hak': ['分類'], + 'haw': ['Māhele'], + 'he': ['קטגוריה', 'קט'], + 'hi': ['श्र', 'श्रेणी'], + 'hif': ['vibhag'], + 'hr': ['CT', 'KT', 'Kategorija'], + 'hsb': ['Kategorija'], + 'ht': ['Kategori'], + 'hu': ['Kategória'], + 'hy': ['Կատեգորիա'], + 'ia': ['Categoria'], + 'id': ['Kategori'], + 'ie': ['Categorie'], + 'ig': ['Ébéonọr', 'Òtù'], + 'ii': ['分类'], + 'ilo': ['Kategoria'], + 'inh': ['ОагӀат'], + 'io': ['Kategorio'], + 'is': ['Flokkur'], + 'it': ['CAT', 'Categoria'], + 'ja': ['カテゴリ'], + 'jbo': ['klesi'], + 'jv': ['Kategori'], + 'ka': ['კატეგორია'], + 'kaa': ['Sanat', 'Kategoriya', 'Санат', 'سانات'], + 'kab': ['Taggayt'], + 'kbd': ['Категория', 'Категориэ'], + 'kbp': ['Catégorie'], + 'kg': ['Kalasi'], + 'kk': ['Sanat', 'Санат', 'سانات'], + 'kl': ['Sumut_atassuseq', 'Kategori', 'Sumut atassuseq'], + 'km': ['ចំនាត់ថ្នាក់ក្រុម', 'ចំណាត់ក្រុម', 'ចំណាត់ថ្នាក់ក្រុម'], + 'kn': ['ವರ್ಗ'], + 'ko': ['분류'], + 'koi': ['Категория'], + 'krc': ['Категория'], + 'ks': ['زٲژ'], + 'ksh': [ + 'Saachjropp', + 'Saachjrop', + 'Katejori', + 'Kategorie', + 'Saachjrupp', + 'Kattejori', + 'Sachjrop', + ], + 'ku': ['Kategorî', 'پۆل'], + 'kv': ['Категория'], + 'kw': ['Class', 'Klass'], + 'ky': ['Категория'], + 'la': ['Categoria'], + 'lad': ['Kateggoría', 'Katēggoría', 'Categoría'], + 'lb': ['Kategorie'], + 'lbe': ['Категория'], + 'lez': ['Категория'], + 'lfn': ['Categoria'], + 'li': ['Categorie', 'Kategorie'], + 'lij': ['Categorîa', 'Categoria'], + 'lmo': ['Categuria', 'Categoria'], + 'ln': ['Catégorie'], + 'lo': ['ໝວດ'], + 'lrc': ['دأسە'], + 'lt': ['Kategorija'], + 'ltg': ['Kategoreja'], + 'lv': ['Kategorija'], + 'mai': ['CA', 'श्रेणी'], + 'map-bms': ['Kategori'], + 'mdf': ['Категорие', 'Категория'], + 'mg': ['Sokajy', 'Catégorie'], + 'mhr': ['Категория', 'Категорий'], + 'min': ['Kategori'], + 'mk': ['Категорија'], + 'ml': ['വിഭാഗം', 'വി', 'വർഗ്ഗം', 'വ'], + 'mn': ['Ангилал'], + 'mr': ['वर्ग'], + 'mrj': ['Категори', 'Категория'], + 'ms': ['Kategori'], + 'mt': ['Kategorija'], + 'mwl': ['Catadorie', 'Categoria'], + 'my': ['ကဏ္ဍ'], + 'myv': ['Категория'], + 'mzn': ['رج', 'رده'], + 'nah': ['Neneuhcāyōtl', 'Categoría'], + 'nap': ['Categurìa', 'Categoria'], + 'nds': ['Kategorie'], + 'nds-nl': ['Categorie', 'Kattegerie', 'Kategorie'], + 'ne': ['श्रेणी'], + 'new': ['पुचः'], + 'nl': ['Categorie'], + 'nn': ['Kategori'], + 'no': ['Kategori'], + 'nrm': ['Catégorie'], + 'nso': ['Setensele'], + 'nv': ['Tʼááłáhági_átʼéego', 'Tʼááłáhági átʼéego'], + 'oc': ['Categoria'], + 'olo': ['Kategourii'], + 'or': ['ବିଭାଗ', 'ଶ୍ରେଣୀ'], + 'os': ['Категори'], + 'pa': ['ਸ਼੍ਰੇਣੀ'], + 'pcd': ['Catégorie'], + 'pdc': ['Abdeeling', 'Kategorie'], + 'pfl': ['Kadegorie', 'Sachgrubb', 'Kategorie'], + 'pi': ['विभाग'], + 'pl': ['Kategoria'], + 'pms': ['Categorìa'], + 'pnb': ['گٹھ'], + 'pnt': ['Κατηγορίαν'], + 'ps': ['وېشنيزه'], + 'pt': ['Categoria'], + 'qu': ['Katiguriya'], + 'rm': ['Categoria'], + 'rmy': ['Shopni'], + 'ro': ['Categorie'], + 'roa-rup': ['Categorie'], + 'roa-tara': ['Categoria'], + 'ru': ['Категория', 'К'], + 'rue': ['Категория', 'Катеґорія'], + 'rw': ['Ikiciro'], + 'sa': ['वर्गः'], + 'sah': ['Категория'], + 'sat': ['ᱛᱷᱚᱠ'], + 'sc': ['Categoria'], + 'scn': ['Catigurìa'], + 'sd': ['زمرو'], + 'se': ['Kategoriija'], + 'sg': ['Catégorie'], + 'sh': ['Kategorija', 'Категорија'], + 'si': ['ප්‍රවර්ගය'], + 'sk': ['Kategória'], + 'sl': ['Kategorija'], + 'sq': ['Kategoria', 'Kategori'], + 'sr': ['Kategorija', 'Категорија'], + 'srn': ['Categorie', 'Guru'], + 'stq': ['Kategorie'], + 'su': ['Kategori'], + 'sv': ['Kategori'], + 'sw': ['Jamii'], + 'szl': ['Kategoryjo', 'Kategoria'], + 'ta': ['பகுப்பு'], + 'tcy': ['ವರ್ಗೊ'], + 'te': ['వర్గం'], + 'tet': ['Kategoría', 'Kategoria'], + 'tg': ['Гурӯҳ'], + 'th': ['หมวดหมู่'], + 'ti': ['መደብ'], + 'tk': ['Kategoriýa'], + 'tl': ['Kategorya', 'Kaurian'], + 'tpi': ['Grup'], + 'tr': ['Kategori', 'KAT'], + 'tt': ['Төркем', 'Törkem', 'Категория'], + 'ty': ['Catégorie'], + 'tyv': ['Аңгылал', 'Категория'], + 'udm': ['Категория'], + 'ug': ['تۈر'], + 'uk': ['Категория', 'Категорія'], + 'ur': ['زمرہ'], + 'uz': ['Turkum', 'Kategoriya'], + 'vec': ['Categoria'], + 'vep': ['Kategorii'], + 'vi': ['Thể_loại', 'Thể loại'], + 'vls': ['Categorie'], + 'vo': ['Klad'], + 'wa': ['Categoreye'], + 'war': ['Kaarangay'], + 'wo': ['Wàll', 'Catégorie'], + 'wuu': ['分类'], + 'xal': ['Янз', 'Әәшл'], + 'xmf': ['კატეგორია'], + 'yi': ['קאטעגאריע', 'קאַטעגאָריע'], + 'yo': ['Ẹ̀ka'], + 'za': ['分类'], + 'zea': ['Categorie'], + 'zh': ['分类', '分類', 'CAT'], + 'zh-classical': ['分類', 'CAT'], + 'zh-min-nan': ['分類', 'Lūi-pia̍t'], + 'zh-yue': ['分类', '分類', '类', '類'], +} + + +class WikipediaDownloader(DocumentDownloader): + + def __init__(self, download_dir, verbose=False): + super().__init__() + self._download_dir = download_dir + self._verbose = verbose + + def download(self, url): + urlpath = urlparse(url).path[1:] + output_name = urlpath.replace('/', '-') + output_file = os.path.join(self._download_dir, output_name) + if os.path.exists(output_file): + print(f'bz2 file: {output_file} exists. Not downloading') + else: + print(f'Downloading {url} and writing to {output_file}') + # Download with either wget or s5cmd (aws) + cmd = ['wget', url, '-O', output_file] + if self._verbose: + stdout, stderr = None, None + else: + stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL + p = subprocess.run( + cmd, + stdout=stdout, + stderr=stderr, + ) + if p.returncode != 0: + print(f'Failed to download {url} to {output_file}') + + return output_file + + +class WikipediaIterator(DocumentIterator): + + def __init__(self, language='en', log_frequency=1000): + super().__init__() + self._language = language + self._log_frequency = log_frequency + self._counter = 0 + + def iterate(self, file_path): + self._counter = 0 + bname = os.path.split(file_path)[-1] + input_file = bz2.BZ2File(filename=file_path) + utf_f = codecs.getreader('utf-8')(input_file) + context = etree.iterparse(utf_f, events=('end', )) + + for i, (unused_event, elem) in enumerate(context): + if not elem.tag.endswith('page'): + continue + if self._counter > 0 and self._counter % self._log_frequency == 0: + print(f'Extracted {self._counter} articles from {file_path}') + self._counter += 1 + + namespace = elem.tag[:-4] + title = elem.find(f'./{namespace}title').text + ns = elem.find(f'./{namespace}ns').text + id_ = elem.find(f'./{namespace}id').text + red_ = elem.find(f'./{namespace}redirect') + + url = f'https://{self._language}.wikipedia.org/wiki/{quote(title)}' + + # Filter pages that are not in the "main" namespace. + if ns != '0': + elem.clear() + continue + + raw_content = elem.find( + f'./{namespace}revision/{namespace}text').text + elem.clear() + + # Filter redirects. + if raw_content is None or red_ is not None: + continue + + yield { + 'title': title, + 'id': id_, + 'url': url, + 'language': self._language, + 'source_id': f'{bname}', + }, raw_content + + +class WikipediaExtractor(DocumentExtractor): + + def __init__(self, language='en', parser=mwparserfromhell): + super().__init__() + self._language = language + self._parser = parser + + def extract(self, content): + wikicode = self._parser.parse(content) + + # Filters for magic words / parser instructions -- e.g., __NOTOC__ + re_rm_magic = re.compile('__[A-Z]*__', flags=re.UNICODE) + + # Filters for file/image links. + media_prefixes = '|'.join(['File', 'Image', 'Media'] + + MEDIA_ALIASES.get(self._language, [])) + re_rm_wikilink = re.compile(f'^(?:{media_prefixes}):', + flags=re.IGNORECASE | re.UNICODE) + + def rm_wikilink(obj): + return bool(re_rm_wikilink.match(str(obj.title))) + + # Filters for references and tables + def rm_tag(obj): + return str(obj.tag) in {'ref', 'table'} + + # Leave category links in-place but remove the category prefixes + cat_prefixes = '|'.join(['Category'] + + CAT_ALIASES.get(self._language, [])) + re_clean_wikilink = re.compile(f'^(?:{cat_prefixes}):', + flags=re.IGNORECASE | re.UNICODE) + + def is_category(obj): + return bool(re_clean_wikilink.match(str(obj.title))) + + def clean_wikilink(obj): + text = obj.__strip__() + text = re.sub(re_clean_wikilink, '', text) + obj.text = text + + def try_replace_obj(obj): + try: + clean_wikilink(obj) + except ValueError: + # For unknown reasons, objects are sometimes not found. + pass + + def try_remove_obj(obj, section): + try: + section.remove(obj) + except ValueError: + # For unknown reasons, objects are sometimes not found. + pass + + section_text = [] + # Filter individual sections to clean. + wiki_code_kwargs = { + 'flat': True, + 'include_lead': True, + 'include_headings': True, + } + for section in wikicode.get_sections(**wiki_code_kwargs): + for obj in section.ifilter_wikilinks(recursive=True): + if rm_wikilink(obj): + try_remove_obj(obj, section) + elif is_category(obj): + try_replace_obj(obj) + for obj in section.ifilter_tags(matches=rm_tag, + recursive=True): + try_remove_obj(obj, section) + + section_text.append( + re.sub( + re_rm_magic, + '', + section.strip_code().strip(), + )) + # Don't return any meta here + return {}, '\n\n'.join(section_text) + + +def download_wikipedia(output_path: str, + language: str = 'en', + dump_date=None, + output_type: str = 'jsonl', + raw_download_dir=None, + keep_raw_download=False, + force_download=False, + url_limit=None, + item_limit=None) -> Dataset: + """ + Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell + + Args: + output_path: The path to the root directory of the files + language: The language of the Wikipedia articles to download + dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. + If None, latest dump is used. + output_type: The file type to save the data as. + raw_download_dir: Path to store the raw download files for intermediate processing. + If None, they are stored in a folder named "downloads" under output_path. + keep_raw_download: If True, keeps the bz2 files that have not been extracted. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + url_limit: The maximum number of raw files to download from the snapshot. If None, all + files from the range of snapshots are downloaded. + """ # noqa: E501 + wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date) + if url_limit: + wikipedia_urls = wikipedia_urls[:url_limit] + output_paths = list( + map( + lambda url: os.path.join(output_path, + url.split('/')[-1] + f'.{output_type}'), + wikipedia_urls, + )) + + if not raw_download_dir: + raw_download_dir = os.path.join(output_path, 'downloads') + expand_outdir_and_mkdir(raw_download_dir) + + downloader = WikipediaDownloader(download_dir=raw_download_dir) + iterator = WikipediaIterator(language=language) + extractor = WikipediaExtractor(language=language) + + output_format = { + 'text': str, + 'title': str, + 'id': str, + 'url': str, + 'language': str, + 'source_id': str, + 'filename': str, + } + dataset = download_and_extract(wikipedia_urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + item_limit=item_limit) + + return dataset diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index e2fc241cd..6f9d12869 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -1,13 +1,16 @@ +import ast import asyncio import copy import hashlib import os import re import shutil +import warnings from datetime import datetime, timezone from pathlib import Path -from typing import AsyncGenerator, List, Union +from typing import AsyncGenerator, List, Optional, Union +import pandas as pd from datasets.utils.extract import ZstdExtractor as Extractor from data_juicer.utils.constant import DEFAULT_PREFIX, Fields @@ -229,3 +232,167 @@ def copy_data(from_dir, to_dir, data_path): os.makedirs(parent_dir) shutil.copy2(from_path, to_path) return True + + +def expand_outdir_and_mkdir(outdir): + _outdir = os.path.abspath(os.path.expanduser(outdir)) + if not os.path.exists(_outdir): + os.makedirs(_outdir) + return _outdir + + +def single_partition_write_with_filename( + df: pd.DataFrame, + output_file_dir: str, + keep_filename_column: bool = False, + output_type: str = 'jsonl', +) -> pd.Series: + """ + This function processes a DataFrame and writes it to disk + + Args: + df: A DataFrame. + output_file_dir: The output file path. + keep_filename_column: Whether to keep or drop the "filename" column, if it exists. + output_type="jsonl": The type of output file to write. + Returns: + If the DataFrame is non-empty, return a Series containing a single element, True. + If the DataFrame is empty, return a Series containing a single element, False. + + """ # noqa: E501 + assert 'filename' in df.columns + + if len(df) > 0: + empty_partition = False + else: + warnings.warn('Empty partition found') + empty_partition = True + + # if is_cudf_type(df): + # import cudf + # success_ser = cudf.Series([empty_partition]) + # else: + success_ser = pd.Series([empty_partition]) + + if not empty_partition: + filenames = df.filename.unique() + filenames = list(filenames) + num_files = len(filenames) + + for filename in filenames: + out_df = df[df.filename == filename] if num_files > 1 else df + if not keep_filename_column: + out_df = out_df.drop('filename', axis=1) + + filename = Path(filename).stem + output_file_path = os.path.join(output_file_dir, filename) + + if output_type == 'jsonl': + output_file_path = output_file_path + '.jsonl' + out_df.to_json( + output_file_path, + orient='records', + lines=True, + force_ascii=False, + ) + + elif output_type == 'parquet': + output_file_path = output_file_path + '.parquet' + out_df.to_parquet(output_file_path) + + else: + raise ValueError(f'Unknown output type: {output_type}') + + return success_ser + + +def read_single_partition( + files, + filetype='jsonl', + add_filename=False, + input_meta: Union[str, dict] = None, + columns: Optional[List[str]] = None, + **kwargs, +) -> pd.DataFrame: + """ + This function reads a file with cuDF, sorts the columns of the DataFrame + and adds a "filename" column. + + Args: + files: The path to the jsonl files to read. + add_filename: Whether to add a "filename" column to the DataFrame. + input_meta: A dictionary or a string formatted as a dictionary, which outlines + the field names and their respective data types within the JSONL input file. + columns: If not None, only these columns will be read from the file. + There is a significant performance gain when specifying columns for Parquet files. + + Returns: + A pandas DataFrame. + + """ # noqa: E501 + if input_meta is not None and filetype != 'jsonl': + warnings.warn('input_meta is only valid for JSONL files and' + 'will be ignored for other file formats..') + + if filetype in ['jsonl', 'json']: + read_kwargs = {'lines': filetype == 'jsonl'} + read_kwargs['dtype'] = False + read_f = pd.read_json + + if input_meta is not None: + read_kwargs['dtype'] \ + = (ast.literal_eval(input_meta) + if type(input_meta) == str else input_meta) + + elif filetype == 'parquet': + read_kwargs = {'columns': columns} + read_f = pd.read_parquet + + else: + raise RuntimeError('Could not read data, please check file type') + + read_files_one_at_a_time = True + if read_files_one_at_a_time: + concat_f = pd.concat + df_ls = [] + for file in files: + df = read_f(file, **read_kwargs, **kwargs) + if add_filename: + df['filename'] = os.path.basename(file) + df_ls.append(df) + df = concat_f(df_ls, ignore_index=True) + else: + df = read_f(files, **read_kwargs, **kwargs) + + if filetype in ['jsonl', 'json'] and columns is not None: + if add_filename and 'filename' not in columns: + columns.append('filename') + df = df[columns] + + df = df[sorted(df.columns)] + return df + + +def get_all_files_paths_under(root, + recurse_subdirectories=True, + followlinks=False): + """ + This function returns a list of all the files under a specified directory. + Args: + root: The path to the directory to read. + recurse_subdirecties: Whether to recurse into subdirectories. + Please note that this can be slow for large + number of files. + followlinks: Whether to follow symbolic links. + """ # noqa: E501 + if recurse_subdirectories: + file_ls = [ + os.path.join(r, f) + for r, subdirs, files in os.walk(root, followlinks=followlinks) + for f in files + ] + else: + file_ls = [entry.path for entry in os.scandir(root)] + + file_ls.sort() + return file_ls diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt index 7e614159b..5396ce025 100644 --- a/environments/minimal_requires.txt +++ b/environments/minimal_requires.txt @@ -26,3 +26,4 @@ multiprocess==0.70.12 dill==0.3.4 psutil pydantic>=2.0 +mwparserfromhell diff --git a/tests/download/__init__.py b/tests/download/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/download/test_download.py b/tests/download/test_download.py new file mode 100644 index 000000000..481c1ff5b --- /dev/null +++ b/tests/download/test_download.py @@ -0,0 +1,23 @@ +import tempfile +from data_juicer.download.wikipedia import ( + get_wikipedia_urls, download_wikipedia +) + +class TestDownload: + + def test_wikipedia_urls(self): + dump_date = "20241101" + urls = get_wikipedia_urls(dump_date=dump_date) + assert len(urls) > 3 + assert urls[0] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2" + assert urls[1] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream2.xml-p41243p151573.bz2" + assert urls[2] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream3.xml-p151574p311329.bz2" + + def test_wikipedia_download(self): + dump_date = "20241101" + output_directory = tempfile.gettempdir() + "/dj_temp/" + url_limit = 5 + item_limit = 10 + wiki_df = download_wikipedia(output_directory, dump_date=dump_date, url_limit=url_limit, item_limit=item_limit) + sample = wiki_df.take(50) + assert len(sample) == 50 \ No newline at end of file From 50f8d3dc484efce9613848945771ef951d396c78 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 21 Nov 2024 21:46:54 -0800 Subject: [PATCH 03/56] refactor formatter; add dataset_builder --- data_juicer/core/dataset_builder.py | 64 +++++++++++++++++++ .../download/{__inif__.py => __init__.py} | 0 data_juicer/download/arxiv.py | 4 +- .../download/{base.py => downloader.py} | 0 data_juicer/download/wikipedia.py | 11 ++-- data_juicer/format/formatter.py | 61 +----------------- data_juicer/format/mixture_formatter.py | 3 +- tests/download/test_download.py | 7 +- 8 files changed, 81 insertions(+), 69 deletions(-) create mode 100644 data_juicer/core/dataset_builder.py rename data_juicer/download/{__inif__.py => __init__.py} (100%) rename data_juicer/download/{base.py => downloader.py} (100%) diff --git a/data_juicer/core/dataset_builder.py b/data_juicer/core/dataset_builder.py new file mode 100644 index 000000000..35376d0b1 --- /dev/null +++ b/data_juicer/core/dataset_builder.py @@ -0,0 +1,64 @@ +import os + +from data_juicer.format import RemoteFormatter +from data_juicer.format.formatter import FORMATTERS, BaseFormatter +from data_juicer.utils.file_utils import (find_files_with_suffix, + is_absolute_path) + + +def load_formatter(dataset_path, + text_keys=None, + suffixes=None, + add_suffix=False, + **kwargs) -> BaseFormatter: + """ + Load the appropriate formatter for different types of data formats. + + :param dataset_path: Path to dataset file or dataset directory + :param text_keys: key names of field that stores sample text. + Default: None + :param suffixes: the suffix of files that will be read. Default: + None + :return: a dataset formatter. + """ + + if suffixes is None: + suffixes = [] + ext_num = {} + if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): + file_dict = find_files_with_suffix(dataset_path, suffixes) + if not file_dict: + raise IOError( + 'Unable to find files matching the suffix from {}'.format( + dataset_path)) + for ext in file_dict: + ext_num[ext] = len(file_dict[ext]) + + # local dataset + if ext_num: + formatter_num = {} + for name, formatter in FORMATTERS.modules.items(): + formatter_num[name] = 0 + for ext in ext_num: + if ext in formatter.SUFFIXES: + formatter_num[name] += ext_num[ext] + formatter = max(formatter_num, key=lambda x: formatter_num[x]) + target_suffixes = set(ext_num.keys()).intersection( + set(FORMATTERS.modules[formatter].SUFFIXES)) + return FORMATTERS.modules[formatter](dataset_path, + text_keys=text_keys, + suffixes=target_suffixes, + add_suffix=add_suffix, + **kwargs) + + # try huggingface dataset hub + elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: + return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) + + # no data + else: + raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' + f'It might be because Data-Juicer doesn\'t support ' + f'the format of this dataset, or the path of this ' + f'dataset is incorrect.Please check if it\'s a valid ' + f'dataset path and retry.') diff --git a/data_juicer/download/__inif__.py b/data_juicer/download/__init__.py similarity index 100% rename from data_juicer/download/__inif__.py rename to data_juicer/download/__init__.py diff --git a/data_juicer/download/arxiv.py b/data_juicer/download/arxiv.py index af77427e1..b810c4c27 100644 --- a/data_juicer/download/arxiv.py +++ b/data_juicer/download/arxiv.py @@ -5,8 +5,8 @@ import tarfile import tempfile -from base import (DocumentDownloader, DocumentExtractor, DocumentIterator, - download_and_extract, get_arxiv_urls) +from downloader import (DocumentDownloader, DocumentExtractor, + DocumentIterator, download_and_extract, get_arxiv_urls) from data_juicer.core.data import DJDataset from data_juicer.utils.file_utils import (expand_outdir_and_mkdir, diff --git a/data_juicer/download/base.py b/data_juicer/download/downloader.py similarity index 100% rename from data_juicer/download/base.py rename to data_juicer/download/downloader.py diff --git a/data_juicer/download/wikipedia.py b/data_juicer/download/wikipedia.py index ae8683b88..6e8e29b52 100644 --- a/data_juicer/download/wikipedia.py +++ b/data_juicer/download/wikipedia.py @@ -3,16 +3,17 @@ import os import re import subprocess +import urllib.parse as up import xml.etree.cElementTree as etree -from urllib.parse import quote, urlparse import mwparserfromhell from datasets import Dataset from data_juicer.utils.file_utils import expand_outdir_and_mkdir -from .base import (DocumentDownloader, DocumentExtractor, DocumentIterator, - download_and_extract, get_wikipedia_urls) +from .downloader import (DocumentDownloader, DocumentExtractor, + DocumentIterator, download_and_extract, + get_wikipedia_urls) # The majority of this code is taken from the HuggingFace # implementation of the Wikipedia dataset preparation: @@ -569,7 +570,7 @@ def __init__(self, download_dir, verbose=False): self._verbose = verbose def download(self, url): - urlpath = urlparse(url).path[1:] + urlpath = up.urlparse(url).path[1:] output_name = urlpath.replace('/', '-') output_file = os.path.join(self._download_dir, output_name) if os.path.exists(output_file): @@ -621,7 +622,7 @@ def iterate(self, file_path): id_ = elem.find(f'./{namespace}id').text red_ = elem.find(f'./{namespace}redirect') - url = f'https://{self._language}.wikipedia.org/wiki/{quote(title)}' + url = f'https://{self._language}.wikipedia.org/wiki/{up.quote(title)}' # noqa: E501 # Filter pages that are not in the "main" namespace. if ns != '0': diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 2a8cd99ed..48690f48b 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -5,8 +5,7 @@ from loguru import logger from data_juicer.utils.constant import Fields -from data_juicer.utils.file_utils import (find_files_with_suffix, - is_absolute_path) +from data_juicer.utils.file_utils import find_files_with_suffix from data_juicer.utils.registry import Registry FORMATTERS = Registry('Formatters') @@ -266,61 +265,3 @@ def rel2abs(sample, path_keys, dataset_dir): 'might not be able to find by Data-Juicer.') return dataset - - -def load_formatter(dataset_path, - text_keys=None, - suffixes=None, - add_suffix=False, - **kwargs) -> BaseFormatter: - """ - Load the appropriate formatter for different types of data formats. - - :param dataset_path: Path to dataset file or dataset directory - :param text_keys: key names of field that stores sample text. - Default: None - :param suffixes: the suffix of files that will be read. Default: - None - :return: a dataset formatter. - """ - - if suffixes is None: - suffixes = [] - ext_num = {} - if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): - file_dict = find_files_with_suffix(dataset_path, suffixes) - if not file_dict: - raise IOError( - 'Unable to find files matching the suffix from {}'.format( - dataset_path)) - for ext in file_dict: - ext_num[ext] = len(file_dict[ext]) - - # local dataset - if ext_num: - formatter_num = {} - for name, formatter in FORMATTERS.modules.items(): - formatter_num[name] = 0 - for ext in ext_num: - if ext in formatter.SUFFIXES: - formatter_num[name] += ext_num[ext] - formatter = max(formatter_num, key=lambda x: formatter_num[x]) - target_suffixes = set(ext_num.keys()).intersection( - set(FORMATTERS.modules[formatter].SUFFIXES)) - return FORMATTERS.modules[formatter](dataset_path, - text_keys=text_keys, - suffixes=target_suffixes, - add_suffix=add_suffix, - **kwargs) - - # try huggingface dataset hub - elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: - return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) - - # no data - else: - raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' - f'It might be because Data-Juicer doesn\'t support ' - f'the format of this dataset, or the path of this ' - f'dataset is incorrect.Please check if it\'s a valid ' - f'dataset path and retry.') diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index 6c13bdd7c..35f2bc578 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -5,7 +5,8 @@ from datasets import Dataset, concatenate_datasets from loguru import logger -from .formatter import BaseFormatter, load_formatter +from ..core.dataset_builder import load_formatter +from .formatter import BaseFormatter class MixtureFormatter(BaseFormatter): diff --git a/tests/download/test_download.py b/tests/download/test_download.py index 481c1ff5b..907c93ded 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -1,3 +1,4 @@ +import unittest import tempfile from data_juicer.download.wikipedia import ( get_wikipedia_urls, download_wikipedia @@ -20,4 +21,8 @@ def test_wikipedia_download(self): item_limit = 10 wiki_df = download_wikipedia(output_directory, dump_date=dump_date, url_limit=url_limit, item_limit=item_limit) sample = wiki_df.take(50) - assert len(sample) == 50 \ No newline at end of file + assert len(sample) == 50 + + +if __name__ == '__main__': + unittest.main() From a089de4b8be2e9bbcfbe33e61ea1abb3509db341 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Tue, 26 Nov 2024 11:09:38 -0800 Subject: [PATCH 04/56] add config files and test entry --- configs/datasets/local_json.yaml | 6 ++++++ configs/datasets/local_parquet.yaml | 6 ++++++ configs/datasets/remote_arxiv.yaml | 9 +++++++++ configs/datasets/remote_commoncrawl.yaml | 10 ++++++++++ configs/datasets/remote_huggingface.yaml | 7 +++++++ configs/datasets/remote_modelscope.yaml | 7 +++++++ configs/datasets/remote_wiki.yaml | 9 +++++++++ tests/config/demo_4_dataset_test.yaml | 22 ++++++++++++++++++++++ tests/config/test_config_funcs.py | 2 ++ 9 files changed, 78 insertions(+) create mode 100644 configs/datasets/local_json.yaml create mode 100644 configs/datasets/local_parquet.yaml create mode 100644 configs/datasets/remote_arxiv.yaml create mode 100644 configs/datasets/remote_commoncrawl.yaml create mode 100644 configs/datasets/remote_huggingface.yaml create mode 100644 configs/datasets/remote_modelscope.yaml create mode 100644 configs/datasets/remote_wiki.yaml create mode 100644 tests/config/demo_4_dataset_test.yaml diff --git a/configs/datasets/local_json.yaml b/configs/datasets/local_json.yaml new file mode 100644 index 000000000..0d7c2252a --- /dev/null +++ b/configs/datasets/local_json.yaml @@ -0,0 +1,6 @@ +# global parameters +project_name: 'dataset-local-json' +dataset: + type: 'local' + format: 'json' + path: 'path/to/json/file' diff --git a/configs/datasets/local_parquet.yaml b/configs/datasets/local_parquet.yaml new file mode 100644 index 000000000..57c33e406 --- /dev/null +++ b/configs/datasets/local_parquet.yaml @@ -0,0 +1,6 @@ +# global parameters +project_name: 'dataset-local-parquet' +dataset: + type: 'local' + format: 'parquet' + path: 'path/to/parquet/file' diff --git a/configs/datasets/remote_arxiv.yaml b/configs/datasets/remote_arxiv.yaml new file mode 100644 index 000000000..fe97674e6 --- /dev/null +++ b/configs/datasets/remote_arxiv.yaml @@ -0,0 +1,9 @@ +# global parameters +project_name: 'dataset-remote-arxiv' +dataset: + type: 'remote' + source: 'arxiv' + lang: 'en' + dump_date: 'latest' + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_commoncrawl.yaml b/configs/datasets/remote_commoncrawl.yaml new file mode 100644 index 000000000..8757d5627 --- /dev/null +++ b/configs/datasets/remote_commoncrawl.yaml @@ -0,0 +1,10 @@ +# global parameters +project_name: 'dataset-remote-commoncrawl' +dataset: + type: 'remote' + source: 'commoncrawl' + start_snapshot: '2020-50' + end_snapshot: '2021-04' + aws: true + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_huggingface.yaml b/configs/datasets/remote_huggingface.yaml new file mode 100644 index 000000000..9f2b4eb19 --- /dev/null +++ b/configs/datasets/remote_huggingface.yaml @@ -0,0 +1,7 @@ +# global parameters +project_name: 'dataset-remote-huggingface' +dataset: + type: 'remote' + source: 'huggingface' + org: "HuggingFaceFW" + name: 'fineweb' diff --git a/configs/datasets/remote_modelscope.yaml b/configs/datasets/remote_modelscope.yaml new file mode 100644 index 000000000..433138d4f --- /dev/null +++ b/configs/datasets/remote_modelscope.yaml @@ -0,0 +1,7 @@ +# global parameters +project_name: 'dataset-remote-modelscope' +dataset: + type: 'remote' + source: 'modelscope' + org: "modelscope" + name: 'clue' diff --git a/configs/datasets/remote_wiki.yaml b/configs/datasets/remote_wiki.yaml new file mode 100644 index 000000000..6e94c7549 --- /dev/null +++ b/configs/datasets/remote_wiki.yaml @@ -0,0 +1,9 @@ +# global parameters +project_name: 'dataset-remote-wiki' +dataset: + type: 'remote' + source: 'wiki' + lang: 'en' + dump_date: 'latest' + force_download: false + url_limit: 2 diff --git a/tests/config/demo_4_dataset_test.yaml b/tests/config/demo_4_dataset_test.yaml new file mode 100644 index 000000000..bffc15f79 --- /dev/null +++ b/tests/config/demo_4_dataset_test.yaml @@ -0,0 +1,22 @@ +# Process config example for Arxiv dataset + +# global parameters +project_name: 'test_demo' +dataset: + type: 'local' + format: 'json' + path: './demos/data/demo-dataset.jsonl' +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.parquet' + +# process schedule +# a list of several process operators with their arguments +process: + - whitespace_normalization_mapper: + - language_id_score_filter: + lang: 'zh' + - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method + lowercase: false # whether to convert text to lower case + ignore_non_character: false + - remove_table_text_mapper: diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 1cb7c4463..76d8d1e8d 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -15,6 +15,8 @@ test_bad_yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'demo_4_test_bad_val.yaml') +test_yaml_dataset_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'demo_4_dataset_test.yaml') class ConfigTest(DataJuicerTestCaseBase): From 5a717d798aa2ca38def806f17a01787ec088032b Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 2 Dec 2024 09:21:56 -0800 Subject: [PATCH 05/56] initial dataset_builder --- data_juicer/core/dataset_builder.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/data_juicer/core/dataset_builder.py b/data_juicer/core/dataset_builder.py index 35376d0b1..8995e31e8 100644 --- a/data_juicer/core/dataset_builder.py +++ b/data_juicer/core/dataset_builder.py @@ -6,6 +6,28 @@ is_absolute_path) +class DatasetBuilder(object): + + def __init__(self, cfg): + self.formatters = self.build(cfg) + + def build_formatters(self, cfg): + # build out all the formatters with cfg + formatters = [] + for ds_config in cfg.configs: + formatters.append(self._build_formatter(ds_config)) + return formatters + + def _build_formatter(self, ds_config): + # initialize formatter based on remote or local dataset config + return None + + def build_dataaset(self): + # handle mixture dataset, nested dataset + for f in self.formatters: + f.load_dataset() + + def load_formatter(dataset_path, text_keys=None, suffixes=None, From ffba7e775c0f62598c66aa140e1340775b6c27dd Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 4 Dec 2024 11:49:28 -0800 Subject: [PATCH 06/56] add mixture dataset support; type/subtype --- configs/datasets/local_json.yaml | 5 ++-- configs/datasets/local_parquet.yaml | 5 ++-- configs/datasets/mixture.yaml | 12 +++++++++ configs/datasets/remote_arxiv.yaml | 2 +- configs/datasets/remote_commoncrawl.yaml | 2 +- configs/datasets/remote_huggingface.yaml | 8 +++--- configs/datasets/remote_modelscope.yaml | 8 +++--- configs/datasets/remote_wiki.yaml | 2 +- data_juicer/core/executor.py | 32 +++++++++++++++++++++--- 9 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 configs/datasets/mixture.yaml diff --git a/configs/datasets/local_json.yaml b/configs/datasets/local_json.yaml index 0d7c2252a..a8d36b270 100644 --- a/configs/datasets/local_json.yaml +++ b/configs/datasets/local_json.yaml @@ -2,5 +2,6 @@ project_name: 'dataset-local-json' dataset: type: 'local' - format: 'json' - path: 'path/to/json/file' + subtype: 'json' + path: + - 'path/to/json/file' diff --git a/configs/datasets/local_parquet.yaml b/configs/datasets/local_parquet.yaml index 57c33e406..3b658f660 100644 --- a/configs/datasets/local_parquet.yaml +++ b/configs/datasets/local_parquet.yaml @@ -2,5 +2,6 @@ project_name: 'dataset-local-parquet' dataset: type: 'local' - format: 'parquet' - path: 'path/to/parquet/file' + subtype: 'parquet' + path: + - 'path/to/parquet/file' diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml new file mode 100644 index 000000000..39e6ae47b --- /dev/null +++ b/configs/datasets/mixture.yaml @@ -0,0 +1,12 @@ +project_name: 'dataset-mixture' +dataset: + - type: 'local' + subtype: 'json' + weight: 1.0 + path: + - 'path/to/json/file' + - type: 'local' + subtype: 'csv' + weight: 1.0 + files: + - 'path/to/csv/file' diff --git a/configs/datasets/remote_arxiv.yaml b/configs/datasets/remote_arxiv.yaml index fe97674e6..f170e0ac2 100644 --- a/configs/datasets/remote_arxiv.yaml +++ b/configs/datasets/remote_arxiv.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-arxiv' dataset: type: 'remote' - source: 'arxiv' + subtype: 'arxiv' lang: 'en' dump_date: 'latest' force_download: false diff --git a/configs/datasets/remote_commoncrawl.yaml b/configs/datasets/remote_commoncrawl.yaml index 8757d5627..444f4bca5 100644 --- a/configs/datasets/remote_commoncrawl.yaml +++ b/configs/datasets/remote_commoncrawl.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-commoncrawl' dataset: type: 'remote' - source: 'commoncrawl' + subtype: 'commoncrawl' start_snapshot: '2020-50' end_snapshot: '2021-04' aws: true diff --git a/configs/datasets/remote_huggingface.yaml b/configs/datasets/remote_huggingface.yaml index 9f2b4eb19..820a01a6d 100644 --- a/configs/datasets/remote_huggingface.yaml +++ b/configs/datasets/remote_huggingface.yaml @@ -2,6 +2,8 @@ project_name: 'dataset-remote-huggingface' dataset: type: 'remote' - source: 'huggingface' - org: "HuggingFaceFW" - name: 'fineweb' + subtype: 'huggingface' + path: "HuggingFaceFW/fineweb" + name: "CC-MAIN-2024-10" + split: "train" + limit: 1000 diff --git a/configs/datasets/remote_modelscope.yaml b/configs/datasets/remote_modelscope.yaml index 433138d4f..266df96a9 100644 --- a/configs/datasets/remote_modelscope.yaml +++ b/configs/datasets/remote_modelscope.yaml @@ -2,6 +2,8 @@ project_name: 'dataset-remote-modelscope' dataset: type: 'remote' - source: 'modelscope' - org: "modelscope" - name: 'clue' + subtype: 'modelscope' + path: 'modelscope/clue' + subset_name: 'afqmc' + split: 'train' + limit: 1000 diff --git a/configs/datasets/remote_wiki.yaml b/configs/datasets/remote_wiki.yaml index 6e94c7549..fb6c85b28 100644 --- a/configs/datasets/remote_wiki.yaml +++ b/configs/datasets/remote_wiki.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-wiki' dataset: type: 'remote' - source: 'wiki' + subtype: 'wiki' lang: 'en' dump_date: 'latest' force_download: false diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index d9445dad0..4d4450cf6 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -1,6 +1,7 @@ import os +from abc import ABC, abstractmethod from time import time -from typing import Optional +from typing import List, Optional, Tuple from jsonargparse import Namespace from loguru import logger @@ -24,7 +25,29 @@ from .tracer import Tracer -class Executor: +class ExecutorBase(ABC): + + @abstractmethod + def __init__(self, cfg: Optional[Namespace] = None): + pass + + @abstractmethod + def run(self, + load_data_np: Optional[PositiveInt] = None, + skip_return=False): + pass + + @abstractmethod + def can_handle_data(self, types: List[Tuple[str, str]]): + """ + types is a list of tuples, [(type, subtype), (type, subtype), ...]; + different executor types will specific whether it can handle these + type/subtype combos + """ + pass + + +class Executor(ExecutorBase): """ This Executor class is used to process a specific dataset. @@ -152,7 +175,7 @@ def run(self, :param skip_return: skip return for API called. :return: processed dataset. """ - # 1. format data + # 1. load data if self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: logger.info('Loading dataset from checkpoint...') dataset = self.ckpt_manager.load_ckpt() @@ -211,3 +234,6 @@ def run(self, if not skip_return: return dataset + + def can_handle_data(self, types: List[Tuple[str, str]]): + pass From 79ae9809794973ca7158999b6c5a4bba7086e7ba Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 4 Dec 2024 12:23:42 -0800 Subject: [PATCH 07/56] RayExecutor with ExecutorBase --- data_juicer/core/ray_executor.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 1d90e31b3..9e5348c45 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -1,6 +1,8 @@ import time +from typing import Optional from loguru import logger +from pydantic import PositiveInt from data_juicer.config import init_configs from data_juicer.core.ray_data import RayDataset @@ -9,12 +11,13 @@ from data_juicer.utils.lazy_loader import LazyLoader from .adapter import Adapter +from .executor import ExecutorBase ray = LazyLoader('ray', 'ray') rd = LazyLoader('rd', 'ray.data') -class RayExecutor: +class RayExecutor(ExecutorBase): """ Executor based on Ray. @@ -42,11 +45,14 @@ def __init__(self, cfg=None): logger.info('Initing Ray ...') ray.init(self.cfg.ray_address) - def run(self, load_data_np=None): + def run(self, + load_data_np: Optional[PositiveInt] = None, + skip_return=False): """ - Running the dataset process pipeline. + Running the dataset process pipeline :param load_data_np: number of workers when loading the dataset. + :param skip_return: skip return for API called. :return: processed dataset. """ # 1. load data @@ -89,4 +95,5 @@ def run(self, load_data_np=None): # 4. data export logger.info('Exporting dataset to disk...') dataset.data.write_json(self.cfg.export_path, force_ascii=False) - return dataset + if not skip_return: + return dataset From e6a6e7197969613a2a5d2cfd4190150a64c5507d Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 4 Dec 2024 14:06:55 -0800 Subject: [PATCH 08/56] get rid of subtype for local dataset; depending on ext for proper routing --- configs/datasets/local_json.yaml | 1 - configs/datasets/local_parquet.yaml | 1 - configs/datasets/mixture.yaml | 2 -- tools/postprocess/data_mixture.py | 2 +- 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/configs/datasets/local_json.yaml b/configs/datasets/local_json.yaml index a8d36b270..e1b27b304 100644 --- a/configs/datasets/local_json.yaml +++ b/configs/datasets/local_json.yaml @@ -2,6 +2,5 @@ project_name: 'dataset-local-json' dataset: type: 'local' - subtype: 'json' path: - 'path/to/json/file' diff --git a/configs/datasets/local_parquet.yaml b/configs/datasets/local_parquet.yaml index 3b658f660..ff9c67a31 100644 --- a/configs/datasets/local_parquet.yaml +++ b/configs/datasets/local_parquet.yaml @@ -2,6 +2,5 @@ project_name: 'dataset-local-parquet' dataset: type: 'local' - subtype: 'parquet' path: - 'path/to/parquet/file' diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index 39e6ae47b..b67a9b5e2 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -1,12 +1,10 @@ project_name: 'dataset-mixture' dataset: - type: 'local' - subtype: 'json' weight: 1.0 path: - 'path/to/json/file' - type: 'local' - subtype: 'csv' weight: 1.0 files: - 'path/to/csv/file' diff --git a/tools/postprocess/data_mixture.py b/tools/postprocess/data_mixture.py index db89a2a1f..c2508796d 100644 --- a/tools/postprocess/data_mixture.py +++ b/tools/postprocess/data_mixture.py @@ -58,7 +58,7 @@ def run_mixture(): e.g. 1) a single data path 2) multiple datasets in the format: dataset1-path - dataset1-file dataset3-path ...' + dataset1-file dataset3-path ...' """ args = parse_args() From eb300f0d3fefd089509be5733b73f450f8a8ac2a Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 4 Dec 2024 14:07:44 -0800 Subject: [PATCH 09/56] use source instead of sub_type for remote dataset configs --- configs/datasets/remote_arxiv.yaml | 2 +- configs/datasets/remote_commoncrawl.yaml | 2 +- configs/datasets/remote_huggingface.yaml | 2 +- configs/datasets/remote_modelscope.yaml | 2 +- configs/datasets/remote_wiki.yaml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/datasets/remote_arxiv.yaml b/configs/datasets/remote_arxiv.yaml index f170e0ac2..fe97674e6 100644 --- a/configs/datasets/remote_arxiv.yaml +++ b/configs/datasets/remote_arxiv.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-arxiv' dataset: type: 'remote' - subtype: 'arxiv' + source: 'arxiv' lang: 'en' dump_date: 'latest' force_download: false diff --git a/configs/datasets/remote_commoncrawl.yaml b/configs/datasets/remote_commoncrawl.yaml index 444f4bca5..8757d5627 100644 --- a/configs/datasets/remote_commoncrawl.yaml +++ b/configs/datasets/remote_commoncrawl.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-commoncrawl' dataset: type: 'remote' - subtype: 'commoncrawl' + source: 'commoncrawl' start_snapshot: '2020-50' end_snapshot: '2021-04' aws: true diff --git a/configs/datasets/remote_huggingface.yaml b/configs/datasets/remote_huggingface.yaml index 820a01a6d..4da90d00a 100644 --- a/configs/datasets/remote_huggingface.yaml +++ b/configs/datasets/remote_huggingface.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-huggingface' dataset: type: 'remote' - subtype: 'huggingface' + source: 'huggingface' path: "HuggingFaceFW/fineweb" name: "CC-MAIN-2024-10" split: "train" diff --git a/configs/datasets/remote_modelscope.yaml b/configs/datasets/remote_modelscope.yaml index 266df96a9..88b76c461 100644 --- a/configs/datasets/remote_modelscope.yaml +++ b/configs/datasets/remote_modelscope.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-modelscope' dataset: type: 'remote' - subtype: 'modelscope' + source: 'modelscope' path: 'modelscope/clue' subset_name: 'afqmc' split: 'train' diff --git a/configs/datasets/remote_wiki.yaml b/configs/datasets/remote_wiki.yaml index fb6c85b28..6e94c7549 100644 --- a/configs/datasets/remote_wiki.yaml +++ b/configs/datasets/remote_wiki.yaml @@ -2,7 +2,7 @@ project_name: 'dataset-remote-wiki' dataset: type: 'remote' - subtype: 'wiki' + source: 'wiki' lang: 'en' dump_date: 'latest' force_download: false From 456eea14e9f69638e9dfee525f6c408810fdcd97 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 4 Dec 2024 14:22:23 -0800 Subject: [PATCH 10/56] arxiv downloader return Dataset instead of DJDataset --- data_juicer/download/arxiv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/download/arxiv.py b/data_juicer/download/arxiv.py index b810c4c27..513474f1c 100644 --- a/data_juicer/download/arxiv.py +++ b/data_juicer/download/arxiv.py @@ -5,10 +5,10 @@ import tarfile import tempfile +from datasets import Dataset from downloader import (DocumentDownloader, DocumentExtractor, DocumentIterator, download_and_extract, get_arxiv_urls) -from data_juicer.core.data import DJDataset from data_juicer.utils.file_utils import (expand_outdir_and_mkdir, get_all_files_paths_under) @@ -355,7 +355,7 @@ def download_arxiv( keep_raw_download=False, force_download=False, url_limit=None, -) -> DJDataset: +) -> Dataset: """ Downloads Arxiv tar files and extracts them From c25e40f59f95ba5317ac8f30ec5ffee01b2b4561 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 5 Dec 2024 10:55:26 -0800 Subject: [PATCH 11/56] rewrite CLI datapath with test cases --- data_juicer/core/dataset_builder.py | 131 ++++++++++++------------ data_juicer/core/executor.py | 6 +- data_juicer/format/mixture_formatter.py | 101 +++++++++++------- data_juicer/utils/file_utils.py | 4 +- data_juicer/utils/sample.py | 35 +++++++ tests/core/data/sample.json | 1 + tests/core/data/sample.txt | 1 + tests/core/data/test_config.yaml | 5 + tests/core/test_dataset_builder.py | 45 ++++++++ 9 files changed, 222 insertions(+), 107 deletions(-) create mode 100644 data_juicer/utils/sample.py create mode 100644 tests/core/data/sample.json create mode 100644 tests/core/data/sample.txt create mode 100644 tests/core/data/test_config.yaml create mode 100644 tests/core/test_dataset_builder.py diff --git a/data_juicer/core/dataset_builder.py b/data_juicer/core/dataset_builder.py index 8995e31e8..03d92da8b 100644 --- a/data_juicer/core/dataset_builder.py +++ b/data_juicer/core/dataset_builder.py @@ -1,86 +1,83 @@ import os +from typing import List, Tuple, Union -from data_juicer.format import RemoteFormatter -from data_juicer.format.formatter import FORMATTERS, BaseFormatter -from data_juicer.utils.file_utils import (find_files_with_suffix, - is_absolute_path) +from data_juicer.core.data import NestedDataset +from data_juicer.core.ray_data import RayDataset +from data_juicer.utils.file_utils import is_absolute_path class DatasetBuilder(object): - def __init__(self, cfg): - self.formatters = self.build(cfg) + def __init__(self, + dataset_cfg, + max_samples=0, + generated_dataset_config=None, + text_keys=None, + suffixes=None, + add_suffix=False, + **kwargs): + self.loaders = [] + # mixture or single - def build_formatters(self, cfg): - # build out all the formatters with cfg - formatters = [] - for ds_config in cfg.configs: - formatters.append(self._build_formatter(ds_config)) - return formatters - - def _build_formatter(self, ds_config): - # initialize formatter based on remote or local dataset config - return None - - def build_dataaset(self): + def load_dataset(self) -> Union[NestedDataset, RayDataset]: # handle mixture dataset, nested dataset + # handle sampling of mixture datasets + # for f in self.formatters: f.load_dataset() + return None -def load_formatter(dataset_path, - text_keys=None, - suffixes=None, - add_suffix=False, - **kwargs) -> BaseFormatter: +def rewrite_cli_datapath(dataset_path) -> List: """ - Load the appropriate formatter for different types of data formats. + rewrite the dataset_path from CLI into proper dataset config format + that is compatible with YAML config style; retrofitting CLI input + of local files and huggingface path - :param dataset_path: Path to dataset file or dataset directory - :param text_keys: key names of field that stores sample text. - Default: None - :param suffixes: the suffix of files that will be read. Default: - None - :return: a dataset formatter. + :param dataset_path: a dataset file or a dataset dir or a list of + them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` + :return: list of dataset configs """ + paths, weights = parse_cli_datapath(dataset_path) + ret = [] + for p, w in zip(paths, weights): + if os.path.isdir(p) or os.path.isfile(p): + # local files + ret.append({'type': 'local', 'path': [p], 'weight': w}) + elif not is_absolute_path(p) and not p.startswith( + '.') and p.count('/') <= 1: + # remote huggingface + ret.append({'type': 'huggingface', 'path': p, 'split': 'train'}) + else: + # + raise ValueError( + f'Unable to load the dataset from [{dataset_path}]. ' + f'Data-Juicer CLI mode only supports local files ' + f'w or w/o weights, or huggingface path') + return ret - if suffixes is None: - suffixes = [] - ext_num = {} - if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): - file_dict = find_files_with_suffix(dataset_path, suffixes) - if not file_dict: - raise IOError( - 'Unable to find files matching the suffix from {}'.format( - dataset_path)) - for ext in file_dict: - ext_num[ext] = len(file_dict[ext]) - # local dataset - if ext_num: - formatter_num = {} - for name, formatter in FORMATTERS.modules.items(): - formatter_num[name] = 0 - for ext in ext_num: - if ext in formatter.SUFFIXES: - formatter_num[name] += ext_num[ext] - formatter = max(formatter_num, key=lambda x: formatter_num[x]) - target_suffixes = set(ext_num.keys()).intersection( - set(FORMATTERS.modules[formatter].SUFFIXES)) - return FORMATTERS.modules[formatter](dataset_path, - text_keys=text_keys, - suffixes=target_suffixes, - add_suffix=add_suffix, - **kwargs) +def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]: + """ + Split every dataset path and its weight. + + :param dataset_path: a dataset file or a dataset dir or a list of + them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` + :return: list of dataset path and list of weights + """ + data_prefix = dataset_path.split() + prefixes = [] + weights = [] - # try huggingface dataset hub - elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: - return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) + for i in range(len(data_prefix)): + try: + value = max(float(data_prefix[i]), 0.0) + weights.append(value) + except: # noqa: E722 + value = data_prefix[i].strip() + # if not set weight, use 1.0 as default + if i == 0 or len(weights) == len(prefixes): + weights.append(1.0) + prefixes.append(value) - # no data - else: - raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' - f'It might be because Data-Juicer doesn\'t support ' - f'the format of this dataset, or the path of this ' - f'dataset is incorrect.Please check if it\'s a valid ' - f'dataset path and retry.') + return prefixes, weights diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 4d4450cf6..882265b9b 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -9,6 +9,7 @@ from data_juicer.config import init_configs from data_juicer.core.data import Dataset +from data_juicer.core.dataset_builder import DatasetBuilder from data_juicer.format.load import load_formatter from data_juicer.format.mixture_formatter import MixtureFormatter from data_juicer.ops import OPERATORS, load_ops @@ -76,14 +77,15 @@ def __init__(self, cfg: Optional[Namespace] = None): f'[{self.cfg.cache_compress}]') cache_utils.CACHE_COMPRESS = self.cfg.cache_compress - # setup formatter - logger.info('Setting up data formatter...') + # setup dataset builder + logger.info('Setting up dataset builder...') self.formatter = load_formatter( dataset_path=self.cfg.dataset_path, generated_dataset_config=self.cfg.generated_dataset_config, text_keys=self.cfg.text_keys, suffixes=self.cfg.suffixes, add_suffix=self.cfg.add_suffix) + self.dataset_builder = DatasetBuilder(cfg) # whether to use checkpoint mechanism. If it's true, Executor will # check if there are existing checkpoints first and try to load the diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index 35f2bc578..ff34fac3b 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -1,12 +1,15 @@ -from itertools import chain, repeat +import os from typing import List, Union import numpy as np from datasets import Dataset, concatenate_datasets from loguru import logger -from ..core.dataset_builder import load_formatter -from .formatter import BaseFormatter +from data_juicer.format.formatter import (FORMATTERS, BaseFormatter, + RemoteFormatter) +from data_juicer.utils.file_utils import (find_files_with_suffix, + is_absolute_path) +from data_juicer.utils.sample import random_sample class MixtureFormatter(BaseFormatter): @@ -89,38 +92,6 @@ def _get_weight(self, data_prefix): prefixes.append(value) return prefixes, weights - @classmethod - def random_sample(cls, dataset, weight=1.0, sample_number=0, seed=None): - """ - Randomly sample a subset from a dataset with weight or number, - if sample number is bigger than 0, we will use sample - number instead of weight. - :param dataset: a HuggingFace dataset - :param weight: sample ratio of dataset - :param sample_number: sample number of dataset - :param seed: random sample seed, if None, 42 as default - :return: a subset of dataset - """ - if seed is None: - seed = 42 - - ds_samples = dataset.num_rows - if sample_number <= 0: - sample_number = int(np.ceil(ds_samples * weight)) - - if sample_number == ds_samples: - return dataset - - sample_index = range(sample_number) - - n_repeat = int(np.ceil(sample_number / ds_samples)) - 1 - if n_repeat > 0: - remain_samples = sample_number - n_repeat * ds_samples - sample_index = chain(*repeat(range(ds_samples), n_repeat), - range(remain_samples)) - - return dataset.shuffle(seed=seed).select(sample_index) - def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: """ Load a mixed dataset. @@ -134,7 +105,7 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: self.sample_numbers, self.formatters): dataset = formatter.load_dataset(num_proc, global_cfg) - sampled = self.random_sample(dataset, weight, sample_num) + sampled = random_sample(dataset, weight, sample_num) logger.info(f'sampled {len(sampled)} from ' f'{len(dataset)}') dataset_list.append(sampled) @@ -143,3 +114,61 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: mixed_dataset = NestedDataset(concatenate_datasets(dataset_list)) logger.info(f'There are {len(mixed_dataset)} in final dataset') return mixed_dataset + + +def load_formatter(dataset_path, + text_keys=None, + suffixes=None, + add_suffix=False, + **kwargs) -> BaseFormatter: + """ + Load the appropriate formatter for different types of data formats. + + :param dataset_path: Path to dataset file or dataset directory + :param text_keys: key names of field that stores sample text. + Default: None + :param suffixes: the suffix of files that will be read. Default: + None + :return: a dataset formatter. + """ + + if suffixes is None: + suffixes = [] + ext_num = {} + if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): + file_dict = find_files_with_suffix(dataset_path, suffixes) + if not file_dict: + raise IOError( + 'Unable to find files matching the suffix from {}'.format( + dataset_path)) + for ext in file_dict: + ext_num[ext] = len(file_dict[ext]) + + # local dataset + if ext_num: + formatter_num = {} + for name, formatter in FORMATTERS.modules.items(): + formatter_num[name] = 0 + for ext in ext_num: + if ext in formatter.SUFFIXES: + formatter_num[name] += ext_num[ext] + formatter = max(formatter_num, key=lambda x: formatter_num[x]) + target_suffixes = set(ext_num.keys()).intersection( + set(FORMATTERS.modules[formatter].SUFFIXES)) + return FORMATTERS.modules[formatter](dataset_path, + text_keys=text_keys, + suffixes=target_suffixes, + add_suffix=add_suffix, + **kwargs) + + # try huggingface dataset hub + elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: + return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) + + # no data + else: + raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' + f'It might be because Data-Juicer doesn\'t support ' + f'the format of this dataset, or the path of this ' + f'dataset is incorrect.Please check if it\'s a valid ' + f'dataset path and retry.') diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index 6f9d12869..43c36091c 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -8,7 +8,7 @@ import warnings from datetime import datetime, timezone from pathlib import Path -from typing import AsyncGenerator, List, Optional, Union +from typing import AsyncGenerator, Dict, List, Optional, Union import pandas as pd from datasets.utils.extract import ZstdExtractor as Extractor @@ -49,7 +49,7 @@ async def follow_read( def find_files_with_suffix( path: Union[str, Path], - suffixes: Union[str, List[str], None] = None) -> List[str]: + suffixes: Union[str, List[str], None] = None) -> Dict[str, List[str]]: """ Traverse a path to find all files with the specified suffixes. diff --git a/data_juicer/utils/sample.py b/data_juicer/utils/sample.py new file mode 100644 index 000000000..17275c588 --- /dev/null +++ b/data_juicer/utils/sample.py @@ -0,0 +1,35 @@ +from itertools import chain, repeat + +import numpy as np + + +def random_sample(dataset, weight=1.0, sample_number=0, seed=None): + """ + Randomly sample a subset from a dataset with weight or number, + if sample number is bigger than 0, we will use sample + number instead of weight. + :param dataset: a HuggingFace dataset + :param weight: sample ratio of dataset + :param sample_number: sample number of dataset + :param seed: random sample seed, if None, 42 as default + :return: a subset of dataset + """ + if seed is None: + seed = 42 + + ds_samples = dataset.num_rows + if sample_number <= 0: + sample_number = int(np.ceil(ds_samples * weight)) + + if sample_number == ds_samples: + return dataset + + sample_index = range(sample_number) + + n_repeat = int(np.ceil(sample_number / ds_samples)) - 1 + if n_repeat > 0: + remain_samples = sample_number - n_repeat * ds_samples + sample_index = chain(*repeat(range(ds_samples), n_repeat), + range(remain_samples)) + + return dataset.shuffle(seed=seed).select(sample_index) diff --git a/tests/core/data/sample.json b/tests/core/data/sample.json new file mode 100644 index 000000000..95ba2d78a --- /dev/null +++ b/tests/core/data/sample.json @@ -0,0 +1 @@ +{"text": "Today is Sunday and it's a happy day!"} \ No newline at end of file diff --git a/tests/core/data/sample.txt b/tests/core/data/sample.txt new file mode 100644 index 000000000..698ad7c54 --- /dev/null +++ b/tests/core/data/sample.txt @@ -0,0 +1 @@ +Today is Sunday and it's a happy day! diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml new file mode 100644 index 000000000..642ecd958 --- /dev/null +++ b/tests/core/data/test_config.yaml @@ -0,0 +1,5 @@ +project_name: 'dataset-local-json' +dataset: + type: 'local' + path: + - 'sample.json' \ No newline at end of file diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py new file mode 100644 index 000000000..f446b680b --- /dev/null +++ b/tests/core/test_dataset_builder.py @@ -0,0 +1,45 @@ +import os +from data_juicer.core.dataset_builder import rewrite_cli_datapath +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + +@SKIPPED_TESTS.register_module() +class DatasetBuilderTest(DataJuicerTestCaseBase): + + def test_rewrite_cli_datapath_local_single_file(self): + dataset_path = "./data/sample.txt" + ans = rewrite_cli_datapath(dataset_path) + self.assertEqual( + [{'path': ['./data/sample.txt'], 'type': 'local', 'weight': 1.0}], ans) + + def test_rewrite_cli_datapath_local_directory(self): + dataset_path = "./data" + ans = rewrite_cli_datapath(dataset_path) + self.assertEqual( + [{'path': ['./data'], 'type': 'local', 'weight': 1.0}], ans) + + def test_rewrite_cli_datapath_absolute_path(self): + dataset_path = os.curdir + "/data/sample.txt" + ans = rewrite_cli_datapath(dataset_path) + self.assertEqual( + [{'type': 'local', 'path': [dataset_path], 'weight': 1.0}], ans) + + def test_rewrite_cli_datapath_hf(self): + dataset_path = "hf-internal-testing/librispeech_asr_dummy" + ans = rewrite_cli_datapath(dataset_path) + self.assertEqual([{'path': 'hf-internal-testing/librispeech_asr_dummy', + 'split': 'train', + 'type': 'huggingface'}], + ans) + + def test_rewrite_cli_datapath_local_wrong_files(self): + dataset_path = "./missingDir" + self.assertRaisesRegex(ValueError, "Unable to load the dataset", + rewrite_cli_datapath, dataset_path) + + def test_rewrite_cli_datapath_with_weights(self): + dataset_path = "0.5 ./data/sample.json ./data/sample.txt" + ans = rewrite_cli_datapath(dataset_path) + self.assertEqual( + [{'path': ['./data/sample.json'], 'type': 'local', 'weight': 0.5}, + {'path': ['./data/sample.txt'], 'type': 'local', 'weight': 1.0}], + ans) \ No newline at end of file From 75ffe3f99863be7953513528a3fa6affc323db26 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 5 Dec 2024 17:56:40 -0800 Subject: [PATCH 12/56] add executor and dataload strategy logic --- .../{local_json.yaml => ondisk_json.yaml} | 4 +- ...local_parquet.yaml => ondisk_parquet.yaml} | 4 +- data_juicer/core/__init__.py | 6 +- data_juicer/core/data/__init__.py | 4 + .../core/{ => data}/dataset_builder.py | 30 +++-- .../core/{data.py => data/dj_dataset.py} | 0 data_juicer/core/data/load_strategy.py | 121 ++++++++++++++++++ .../core/{ray_data.py => data/ray_dataset.py} | 114 ++++++++--------- data_juicer/core/executor/__init__.py | 6 + data_juicer/core/executor/base.py | 20 +++ data_juicer/core/executor/factory.py | 23 ++++ .../local_executor.py} | 53 ++------ .../core/{ => executor}/ray_executor.py | 13 +- data_juicer/download/downloader.py | 3 +- data_juicer/utils/sample.py | 54 ++++++++ data_juicer/utils/unittest_utils.py | 2 +- tests/core/test_dataset_builder.py | 12 +- tools/hpo/execute_hpo_3sigma.py | 2 +- tools/hpo/objects.py | 2 +- tools/process_data.py | 2 +- 20 files changed, 339 insertions(+), 136 deletions(-) rename configs/datasets/{local_json.yaml => ondisk_json.yaml} (54%) rename configs/datasets/{local_parquet.yaml => ondisk_parquet.yaml} (54%) create mode 100644 data_juicer/core/data/__init__.py rename data_juicer/core/{ => data}/dataset_builder.py (75%) rename data_juicer/core/{data.py => data/dj_dataset.py} (100%) create mode 100644 data_juicer/core/data/load_strategy.py rename data_juicer/core/{ray_data.py => data/ray_dataset.py} (100%) create mode 100644 data_juicer/core/executor/__init__.py create mode 100644 data_juicer/core/executor/base.py create mode 100644 data_juicer/core/executor/factory.py rename data_juicer/core/{executor.py => executor/local_executor.py} (87%) rename data_juicer/core/{ => executor}/ray_executor.py (92%) diff --git a/configs/datasets/local_json.yaml b/configs/datasets/ondisk_json.yaml similarity index 54% rename from configs/datasets/local_json.yaml rename to configs/datasets/ondisk_json.yaml index e1b27b304..985fd71f4 100644 --- a/configs/datasets/local_json.yaml +++ b/configs/datasets/ondisk_json.yaml @@ -1,6 +1,6 @@ # global parameters -project_name: 'dataset-local-json' +project_name: 'dataset-ondisk-json' dataset: - type: 'local' + type: 'ondisk' path: - 'path/to/json/file' diff --git a/configs/datasets/local_parquet.yaml b/configs/datasets/ondisk_parquet.yaml similarity index 54% rename from configs/datasets/local_parquet.yaml rename to configs/datasets/ondisk_parquet.yaml index ff9c67a31..ce7ad46cf 100644 --- a/configs/datasets/local_parquet.yaml +++ b/configs/datasets/ondisk_parquet.yaml @@ -1,6 +1,6 @@ # global parameters -project_name: 'dataset-local-parquet' +project_name: 'dataset-ondisk-parquet' dataset: - type: 'local' + type: 'ondisk' path: - 'path/to/parquet/file' diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index 6fbb9d5d6..edb24ad7b 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,7 +1,7 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import Executor +from .executor import ExecutorFactory, LocalExecutor, RayExecutor from .exporter import Exporter from .monitor import Monitor from .tracer import Tracer @@ -10,7 +10,9 @@ 'Adapter', 'Analyzer', 'NestedDataset', - 'Executor', + 'ExecutorFactory', + 'LocalExecutor', + 'RayExecutor', 'Exporter', 'Monitor', 'Tracer', diff --git a/data_juicer/core/data/__init__.py b/data_juicer/core/data/__init__.py new file mode 100644 index 000000000..0c8ec69ca --- /dev/null +++ b/data_juicer/core/data/__init__.py @@ -0,0 +1,4 @@ +from .dj_dataset import DJDataset, NestedDataset +from .ray_dataset import RayDataset + +__all__ = ['DJDataset', 'NestedDataset', 'RayDataset'] diff --git a/data_juicer/core/dataset_builder.py b/data_juicer/core/data/dataset_builder.py similarity index 75% rename from data_juicer/core/dataset_builder.py rename to data_juicer/core/data/dataset_builder.py index 03d92da8b..1dd5a830c 100644 --- a/data_juicer/core/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -2,22 +2,24 @@ from typing import List, Tuple, Union from data_juicer.core.data import NestedDataset -from data_juicer.core.ray_data import RayDataset +from data_juicer.core.data.ray_dataset import RayDataset from data_juicer.utils.file_utils import is_absolute_path class DatasetBuilder(object): - def __init__(self, - dataset_cfg, - max_samples=0, - generated_dataset_config=None, - text_keys=None, - suffixes=None, - add_suffix=False, - **kwargs): - self.loaders = [] - # mixture or single + def __init__(self, cfg): + if cfg.dataset_path is not None: + ds_configs = rewrite_cli_datapath(cfg.dataset_path) + elif cfg.dataset is not None: + ds_configs = cfg.dataset + else: + raise ValueError( + 'Unable to initialize dataset; should have one of ' + 'dataset_path or dataset in configurations') + for config in ds_configs: + # initialize data loader from strategy + pass def load_dataset(self) -> Union[NestedDataset, RayDataset]: # handle mixture dataset, nested dataset @@ -43,9 +45,9 @@ def rewrite_cli_datapath(dataset_path) -> List: for p, w in zip(paths, weights): if os.path.isdir(p) or os.path.isfile(p): # local files - ret.append({'type': 'local', 'path': [p], 'weight': w}) - elif not is_absolute_path(p) and not p.startswith( - '.') and p.count('/') <= 1: + ret.append({'type': 'ondisk', 'path': [p], 'weight': w}) + elif (not is_absolute_path(p) and not p.startswith('.') + and p.count('/') <= 1): # remote huggingface ret.append({'type': 'huggingface', 'path': p, 'split': 'train'}) else: diff --git a/data_juicer/core/data.py b/data_juicer/core/data/dj_dataset.py similarity index 100% rename from data_juicer/core/data.py rename to data_juicer/core/data/dj_dataset.py diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py new file mode 100644 index 000000000..8b8d93e7c --- /dev/null +++ b/data_juicer/core/data/load_strategy.py @@ -0,0 +1,121 @@ +from abc import ABC, abstractmethod +from typing import Union + +from data_juicer.core.data import DJDataset, RayDataset + +# based on executor type and data source type, use different +# data load strategy to product corresponding datasets +# DJDataset, RayDataset, DaskDataset, etc + + +class DataLoadStrategyRegistry: + + def __init__(self): + self._registry = {} + + def register(self, key: tuple, strategy): + """Register a strategy for a specific tuple key.""" + if key in self._registry: + raise ValueError(f'Strategy for key {key} is already registered.') + self._registry[key] = strategy + + def get_strategy(self, key: tuple): + """Retrieve the strategy for a specific tuple key.""" + if key not in self._registry: + raise ValueError(f'No strategy registered for key {key}.') + return self._registry[key] + + def register_decorator(self, key: tuple): + """Decorator for registering a strategy with a specific tuple key.""" + + def decorator(func): + self.register(key, func) + return func # Return the original function + + return decorator + + +DATALOAD_STRATEGY_REGISTRY = DataLoadStrategyRegistry() + + +class DataLoadStrategyFactory: + + @classmethod + def create_dataload_strategy(cls, executor_type, dataset_type, + dataset_source): + DATALOAD_STRATEGY_REGISTRY.get_strategy( + (executor_type, dataset_type, dataset_source)) + + +class DataLoadStrategy(ABC): + + @abstractmethod + def load_data(self) -> Union[DJDataset, RayDataset]: + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator(('ray', 'ondisk', 'json')) +class RayOndiskJsonDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator( + ('ray', 'remote', 'huggingface')) +class RayHuggingfaceDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Json')) +class LocalOndiskJsonDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Parquet')) +class LocalOndiskParquetDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator( + ('local', 'remote', 'huggingface')) +class LocalHuggingfaceDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator( + ('local', 'remote', 'modelscope')) +class LocalModelScopeDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'arxiv')) +class LocalArxivDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'wiki')) +class LocalWikiDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass + + +@DATALOAD_STRATEGY_REGISTRY.register_decorator( + ('local', 'remote', 'commoncrawl')) +class LocalCommonCrawlDataLoadStrategy(DataLoadStrategy): + + def load_data(self): + pass diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/data/ray_dataset.py similarity index 100% rename from data_juicer/core/ray_data.py rename to data_juicer/core/data/ray_dataset.py index 0c131561e..352cd95dc 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/data/ray_dataset.py @@ -13,6 +13,63 @@ rd = LazyLoader('rd', 'ray.data') +class RayDataset(DJDataset): + + def __init__(self, + dataset: rd.Dataset, + dataset_path: str = None, + cfg=None) -> None: + self.data = preprocess_dataset(dataset, dataset_path, cfg) + self.num_proc = None + if cfg: + self.num_proc = cfg.np + + def process(self, + operators, + *, + exporter=None, + checkpointer=None, + tracer=None) -> DJDataset: + if operators is None: + return self + if not isinstance(operators, list): + operators = [operators] + for op in operators: + self._run_single_op(op) + return self + + def _run_single_op(self, op): + op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, + self.num_proc, op.use_cuda()) + num_gpus = get_num_gpus(op, op_proc) + try: + batch_size = getattr(op, 'batch_size', + 1) if op.is_batched_op() else 1 + if isinstance(op, Mapper): + self.data = self.data.map_batches(op.process, + batch_size=batch_size, + batch_format='pyarrow', + num_gpus=num_gpus) + elif isinstance(op, Filter): + self.data = self.data.map_batches(op.compute_stats, + batch_size=batch_size, + batch_format='pyarrow', + num_gpus=num_gpus) + if op.stats_export_path is not None: + self.data.write_json(op.stats_export_path, + force_ascii=False) + self.data = self.data.filter(op.process) + else: + logger.error( + 'Ray executor only support Filter and Mapper OPs for now') + raise NotImplementedError + except: # noqa: E722 + logger.error(f'An error occurred during Op [{op._name}].') + import traceback + traceback.print_exc() + exit(1) + + def is_valid_path(item, dataset_dir): full_path = os.path.abspath(os.path.join(dataset_dir, item)) return os.path.exists(full_path) @@ -75,60 +132,3 @@ def get_num_gpus(op, op_proc): return 0 proc_per_gpu = op_proc / cuda_device_count() return 1.0 / proc_per_gpu - - -class RayDataset(DJDataset): - - def __init__(self, - dataset: rd.Dataset, - dataset_path: str = None, - cfg=None) -> None: - self.data = preprocess_dataset(dataset, dataset_path, cfg) - self.num_proc = None - if cfg: - self.num_proc = cfg.np - - def process(self, - operators, - *, - exporter=None, - checkpointer=None, - tracer=None) -> DJDataset: - if operators is None: - return self - if not isinstance(operators, list): - operators = [operators] - for op in operators: - self._run_single_op(op) - return self - - def _run_single_op(self, op): - op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, - self.num_proc, op.use_cuda()) - num_gpus = get_num_gpus(op, op_proc) - try: - batch_size = getattr(op, 'batch_size', - 1) if op.is_batched_op() else 1 - if isinstance(op, Mapper): - self.data = self.data.map_batches(op.process, - batch_size=batch_size, - batch_format='pyarrow', - num_gpus=num_gpus) - elif isinstance(op, Filter): - self.data = self.data.map_batches(op.compute_stats, - batch_size=batch_size, - batch_format='pyarrow', - num_gpus=num_gpus) - if op.stats_export_path is not None: - self.data.write_json(op.stats_export_path, - force_ascii=False) - self.data = self.data.filter(op.process) - else: - logger.error( - 'Ray executor only support Filter and Mapper OPs for now') - raise NotImplementedError - except: # noqa: E722 - logger.error(f'An error occurred during Op [{op._name}].') - import traceback - traceback.print_exc() - exit(1) diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py new file mode 100644 index 000000000..31c448137 --- /dev/null +++ b/data_juicer/core/executor/__init__.py @@ -0,0 +1,6 @@ +from .base import ExecutorBase +from .factory import ExecutorFactory +from .local_executor import LocalExecutor +from .ray_executor import RayExecutor + +__all__ = ['ExecutorBase', 'ExecutorFactory', 'LocalExecutor', 'RayExecutor'] diff --git a/data_juicer/core/executor/base.py b/data_juicer/core/executor/base.py new file mode 100644 index 000000000..9ed7e602d --- /dev/null +++ b/data_juicer/core/executor/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from jsonargparse import Namespace +from pydantic import PositiveInt + +from data_juicer.config import init_configs + + +class ExecutorBase(ABC): + + @abstractmethod + def __init__(self, cfg: Optional[Namespace] = None): + self.cfg = init_configs() if cfg is None else cfg + + @abstractmethod + def run(self, + load_data_np: Optional[PositiveInt] = None, + skip_return=False): + raise NotImplementedError diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py new file mode 100644 index 000000000..22944068e --- /dev/null +++ b/data_juicer/core/executor/factory.py @@ -0,0 +1,23 @@ +from typing import Union + +from local_executor import LocalExecutor +from ray_executor import RayExecutor + + +class ExecutorFactory: + + @staticmethod + def create_executor( + executor_type: str) -> Union[LocalExecutor, RayExecutor]: + if executor_type == 'local': + return LocalExecutor() + elif executor_type == 'ray': + return RayExecutor() + # TODO: add nemo support + # elif executor_type == "nemo": + # return NemoExecutor() + # TODO: add dask support + # elif executor_type == "dask": + # return DaskExecutor() + else: + raise ValueError('Unsupported executor type') diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor/local_executor.py similarity index 87% rename from data_juicer/core/executor.py rename to data_juicer/core/executor/local_executor.py index 882265b9b..b0eca6435 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -1,54 +1,30 @@ import os -from abc import ABC, abstractmethod from time import time -from typing import List, Optional, Tuple +from typing import Optional +from datasets import Dataset from jsonargparse import Namespace from loguru import logger from pydantic import PositiveInt -from data_juicer.config import init_configs -from data_juicer.core.data import Dataset -from data_juicer.core.dataset_builder import DatasetBuilder +from data_juicer.core.adapter import Adapter +from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.core.executor import ExecutorBase +from data_juicer.core.exporter import Exporter +from data_juicer.core.tracer import Tracer from data_juicer.format.load import load_formatter from data_juicer.format.mixture_formatter import MixtureFormatter from data_juicer.ops import OPERATORS, load_ops from data_juicer.ops.op_fusion import fuse_operators -from data_juicer.utils import cache_utils -from data_juicer.utils.ckpt_utils import CheckpointManager - -from ..ops.selector.frequency_specified_field_selector import \ +from data_juicer.ops.selector.frequency_specified_field_selector import \ FrequencySpecifiedFieldSelector -from ..ops.selector.topk_specified_field_selector import \ +from data_juicer.ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector -from .adapter import Adapter -from .exporter import Exporter -from .tracer import Tracer - - -class ExecutorBase(ABC): - - @abstractmethod - def __init__(self, cfg: Optional[Namespace] = None): - pass - - @abstractmethod - def run(self, - load_data_np: Optional[PositiveInt] = None, - skip_return=False): - pass - - @abstractmethod - def can_handle_data(self, types: List[Tuple[str, str]]): - """ - types is a list of tuples, [(type, subtype), (type, subtype), ...]; - different executor types will specific whether it can handle these - type/subtype combos - """ - pass +from data_juicer.utils import cache_utils +from data_juicer.utils.ckpt_utils import CheckpointManager -class Executor(ExecutorBase): +class LocalExecutor(ExecutorBase): """ This Executor class is used to process a specific dataset. @@ -62,7 +38,7 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace. """ - self.cfg = init_configs() if cfg is None else cfg + super().__init__(cfg) self.work_dir = self.cfg.work_dir @@ -236,6 +212,3 @@ def run(self, if not skip_return: return dataset - - def can_handle_data(self, types: List[Tuple[str, str]]): - pass diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/executor/ray_executor.py similarity index 92% rename from data_juicer/core/ray_executor.py rename to data_juicer/core/executor/ray_executor.py index 9e5348c45..c6fcef5e8 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -4,15 +4,13 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.config import init_configs -from data_juicer.core.ray_data import RayDataset +from data_juicer.core.adapter import Adapter +from data_juicer.core.data.ray_dataset import RayDataset +from data_juicer.core.executor import ExecutorBase from data_juicer.ops import load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader -from .adapter import Adapter -from .executor import ExecutorBase - ray = LazyLoader('ray', 'ray') rd = LazyLoader('rd', 'ray.data') @@ -35,14 +33,13 @@ def __init__(self, cfg=None): :param cfg: optional config dict. """ - self.cfg = init_configs() if cfg is None else cfg + super().__init__(cfg) self.work_dir = self.cfg.work_dir - self.adapter = Adapter(self.cfg) # init ray - logger.info('Initing Ray ...') + logger.info('Initializing Ray ...') ray.init(self.cfg.ray_address) def run(self, diff --git a/data_juicer/download/downloader.py b/data_juicer/download/downloader.py index b15c67bc2..a5b1db3ae 100644 --- a/data_juicer/download/downloader.py +++ b/data_juicer/download/downloader.py @@ -51,7 +51,7 @@ def extract(self, content): pass -def _download_and_extract_single_partition(paths: List[Tuple[str, str]], +def _download_and_extract_single_partition(paths: Tuple[str, str], downloader: DocumentDownloader, iterator: DocumentIterator, extractor: DocumentExtractor, @@ -131,6 +131,7 @@ def download_and_extract(urls: List[str], directly read from them instead. input_meta: A dictionary or a string formatted as a dictionary, which outlines the field names and their respective data types within the JSONL input file. + item_limit: limit on number of items downloaded; for sampling and testing purposes Returns: A HuggingFace DataSet of the downloaded data diff --git a/data_juicer/utils/sample.py b/data_juicer/utils/sample.py index 17275c588..0164dbec0 100644 --- a/data_juicer/utils/sample.py +++ b/data_juicer/utils/sample.py @@ -1,6 +1,60 @@ from itertools import chain, repeat import numpy as np +from datasets import Dataset +from loguru import logger + +from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector, + TopkSpecifiedFieldSelector) + + +class SamplingMixin: + + def sample_data(self, + dataset_to_sample: Dataset = None, + load_data_np=None, + sample_ratio: float = 1.0, + sample_algo: str = 'uniform', + **kwargs): + """ + Sample a subset from the given dataset. + + :param dataset_to_sample: Dataset to sample from. If None, will use + the formatter linked by the executor. Default is None. + :param load_data_np: number of workers when loading the dataset. + :param sample_ratio: The ratio of the sample size to the original + dataset size. Default is 1.0 (no sampling). + :param sample_algo: Sampling algorithm to use. Options are "uniform", + "frequency_specified_field_selector", or + "topk_specified_field_selector". + Default is "uniform". + :return: A sampled Dataset. + """ + # Determine the dataset to sample from + if dataset_to_sample is not None: + dataset = dataset_to_sample + elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: + logger.info('Loading dataset from checkpoint...') + dataset = self.ckpt_manager.load_ckpt() + elif hasattr(self, 'formatter'): + logger.info('Loading dataset from data formatter...') + if load_data_np is None: + load_data_np = self.cfg.np + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + else: + raise ValueError('No dataset available to sample from.') + + # Perform sampling based on the specified algorithm + if sample_algo == 'uniform': + return random_sample(dataset, sample_ratio) + elif sample_algo == 'frequency_specified_field_selector': + dj_op = FrequencySpecifiedFieldSelector(**kwargs) + return dj_op.process(dataset) + elif sample_algo == 'topk_specified_field_selector': + dj_op = TopkSpecifiedFieldSelector(**kwargs) + return dj_op.process(dataset) + else: + raise ValueError(f'Unsupported sample_algo: {sample_algo}') def random_sample(dataset, weight=1.0, sample_number=0, seed=None): diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index 1e66c55cc..5e8db5808 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -6,7 +6,7 @@ from data_juicer import is_cuda_available from data_juicer.core.data import DJDataset, NestedDataset -from data_juicer.core.ray_data import RayDataset +from data_juicer.core.data.ray_dataset import RayDataset from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import free_models from data_juicer.utils.registry import Registry diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index f446b680b..ad55ec867 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -1,5 +1,5 @@ import os -from data_juicer.core.dataset_builder import rewrite_cli_datapath +from data_juicer.core.data.dataset_builder import rewrite_cli_datapath from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS @SKIPPED_TESTS.register_module() @@ -9,19 +9,19 @@ def test_rewrite_cli_datapath_local_single_file(self): dataset_path = "./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data/sample.txt'], 'type': 'local', 'weight': 1.0}], ans) + [{'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans) def test_rewrite_cli_datapath_local_directory(self): dataset_path = "./data" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data'], 'type': 'local', 'weight': 1.0}], ans) + [{'path': ['./data'], 'type': 'ondisk', 'weight': 1.0}], ans) def test_rewrite_cli_datapath_absolute_path(self): dataset_path = os.curdir + "/data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'type': 'local', 'path': [dataset_path], 'weight': 1.0}], ans) + [{'type': 'ondisk', 'path': [dataset_path], 'weight': 1.0}], ans) def test_rewrite_cli_datapath_hf(self): dataset_path = "hf-internal-testing/librispeech_asr_dummy" @@ -40,6 +40,6 @@ def test_rewrite_cli_datapath_with_weights(self): dataset_path = "0.5 ./data/sample.json ./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data/sample.json'], 'type': 'local', 'weight': 0.5}, - {'path': ['./data/sample.txt'], 'type': 'local', 'weight': 1.0}], + [{'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5}, + {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans) \ No newline at end of file diff --git a/tools/hpo/execute_hpo_3sigma.py b/tools/hpo/execute_hpo_3sigma.py index 4073fa0d9..0d8cdcd3d 100644 --- a/tools/hpo/execute_hpo_3sigma.py +++ b/tools/hpo/execute_hpo_3sigma.py @@ -36,7 +36,7 @@ def main(): if cfg.executor_type == 'default': executor = Executor(cfg) elif cfg.executor_type == 'ray': - from data_juicer.core.ray_executor import RayExecutor + from data_juicer.core.executor.ray_executor import RayExecutor executor = RayExecutor(cfg) executor.run() diff --git a/tools/hpo/objects.py b/tools/hpo/objects.py index eff8eff59..e398d1550 100644 --- a/tools/hpo/objects.py +++ b/tools/hpo/objects.py @@ -33,7 +33,7 @@ def obj_quality_score(dj_cfg): if dj_cfg.executor_type == 'default': executor = Executor(dj_cfg) elif dj_cfg.executor_type == 'ray': - from data_juicer.core.ray_executor import RayExecutor + from data_juicer.core.executor.ray_executor import RayExecutor executor = RayExecutor(dj_cfg) else: raise NotImplementedError( diff --git a/tools/process_data.py b/tools/process_data.py index a97ef9a40..71241cf41 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -10,7 +10,7 @@ def main(): if cfg.executor_type == 'default': executor = Executor(cfg) elif cfg.executor_type == 'ray': - from data_juicer.core.ray_executor import RayExecutor + from data_juicer.core.executor.ray_executor import RayExecutor executor = RayExecutor(cfg) executor.run() From 4fb6e176de7adb7b638a3518fe8370deb645234f Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 6 Dec 2024 08:48:37 -0800 Subject: [PATCH 13/56] add layered load strategies --- configs/datasets/ondisk_json.yaml | 6 ++-- configs/datasets/ondisk_parquet.yaml | 6 ++-- configs/datasets/remote_arxiv.yaml | 12 ++++---- configs/datasets/remote_commoncrawl.yaml | 14 ++++----- configs/datasets/remote_huggingface.yaml | 12 ++++---- data_juicer/config/config.py | 9 +++++- data_juicer/core/data/load_strategy.py | 39 ++++++++++++++++++------ data_juicer/core/sandbox/factories.py | 2 +- 8 files changed, 64 insertions(+), 36 deletions(-) diff --git a/configs/datasets/ondisk_json.yaml b/configs/datasets/ondisk_json.yaml index 985fd71f4..af3ac7049 100644 --- a/configs/datasets/ondisk_json.yaml +++ b/configs/datasets/ondisk_json.yaml @@ -1,6 +1,6 @@ # global parameters project_name: 'dataset-ondisk-json' dataset: - type: 'ondisk' - path: - - 'path/to/json/file' + - type: 'ondisk' + path: + - 'path/to/json/file' diff --git a/configs/datasets/ondisk_parquet.yaml b/configs/datasets/ondisk_parquet.yaml index ce7ad46cf..0e5256a20 100644 --- a/configs/datasets/ondisk_parquet.yaml +++ b/configs/datasets/ondisk_parquet.yaml @@ -1,6 +1,6 @@ # global parameters project_name: 'dataset-ondisk-parquet' dataset: - type: 'ondisk' - path: - - 'path/to/parquet/file' + - type: 'ondisk' + path: + - 'path/to/parquet/file' diff --git a/configs/datasets/remote_arxiv.yaml b/configs/datasets/remote_arxiv.yaml index fe97674e6..1d14596c2 100644 --- a/configs/datasets/remote_arxiv.yaml +++ b/configs/datasets/remote_arxiv.yaml @@ -1,9 +1,9 @@ # global parameters project_name: 'dataset-remote-arxiv' dataset: - type: 'remote' - source: 'arxiv' - lang: 'en' - dump_date: 'latest' - force_download: false - url_limit: 2 + - type: 'remote' + source: 'arxiv' + lang: 'en' + dump_date: 'latest' + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_commoncrawl.yaml b/configs/datasets/remote_commoncrawl.yaml index 8757d5627..ce9f7a049 100644 --- a/configs/datasets/remote_commoncrawl.yaml +++ b/configs/datasets/remote_commoncrawl.yaml @@ -1,10 +1,10 @@ # global parameters project_name: 'dataset-remote-commoncrawl' dataset: - type: 'remote' - source: 'commoncrawl' - start_snapshot: '2020-50' - end_snapshot: '2021-04' - aws: true - force_download: false - url_limit: 2 + - type: 'remote' + source: 'commoncrawl' + start_snapshot: '2020-50' + end_snapshot: '2021-04' + aws: true + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_huggingface.yaml b/configs/datasets/remote_huggingface.yaml index 4da90d00a..13738a6c5 100644 --- a/configs/datasets/remote_huggingface.yaml +++ b/configs/datasets/remote_huggingface.yaml @@ -1,9 +1,9 @@ # global parameters project_name: 'dataset-remote-huggingface' dataset: - type: 'remote' - source: 'huggingface' - path: "HuggingFaceFW/fineweb" - name: "CC-MAIN-2024-10" - split: "train" - limit: 1000 + - type: 'remote' + source: 'huggingface' + path: "HuggingFaceFW/fineweb" + name: "CC-MAIN-2024-10" + split: "train" + limit: 1000 diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 76a20b786..c27f70c5c 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -10,7 +10,7 @@ import yaml from jsonargparse import (ActionConfigFile, ArgumentParser, Namespace, dict_to_namespace, namespace_to_dict) -from jsonargparse.typehints import ActionTypeHint +from jsonargparse._typehints import ActionTypeHint from jsonargparse.typing import ClosedUnitInterval, NonNegativeInt, PositiveInt from loguru import logger @@ -87,6 +87,13 @@ def init_configs(args: Optional[List[str]] = None): help='Path to datasets with optional weights(0.0-1.0), 1.0 as ' 'default. Accepted format: dataset1-path dataset2-path ' ' dataset3-path ...') + parser.add_argument( + '--dataset', + type=Union[List[Dict], Dict], + default=[], + help='Dataset setting to define local/remote datasets; could be a ' + 'dict or a list of dicts; refer to configs/datasets for more ' + 'detailed examples') parser.add_argument( '--generated_dataset_config', type=Dict, diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 8b8d93e7c..3d44bf34d 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -54,8 +54,29 @@ def load_data(self) -> Union[DJDataset, RayDataset]: pass +class RayDataLoadStrategy(DataLoadStrategy): + + @abstractmethod + def load_data(self) -> RayDataset: + pass + + +class LocalDataLoadStrategy(DataLoadStrategy): + + @abstractmethod + def load_data(self) -> DJDataset: + pass + + +# TODO dask support +# class DaskDataLoadStrategy(DataLoadStrategy): +# @abstractmethod +# def load_data(self) -> Union[DaskDataset]: +# pass + + @DATALOAD_STRATEGY_REGISTRY.register_decorator(('ray', 'ondisk', 'json')) -class RayOndiskJsonDataLoadStrategy(DataLoadStrategy): +class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): def load_data(self): pass @@ -63,21 +84,21 @@ def load_data(self): @DATALOAD_STRATEGY_REGISTRY.register_decorator( ('ray', 'remote', 'huggingface')) -class RayHuggingfaceDataLoadStrategy(DataLoadStrategy): +class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): def load_data(self): pass @DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Json')) -class LocalOndiskJsonDataLoadStrategy(DataLoadStrategy): +class LocalOndiskJsonDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Parquet')) -class LocalOndiskParquetDataLoadStrategy(DataLoadStrategy): +class LocalOndiskParquetDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @@ -85,7 +106,7 @@ def load_data(self): @DATALOAD_STRATEGY_REGISTRY.register_decorator( ('local', 'remote', 'huggingface')) -class LocalHuggingfaceDataLoadStrategy(DataLoadStrategy): +class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @@ -93,21 +114,21 @@ def load_data(self): @DATALOAD_STRATEGY_REGISTRY.register_decorator( ('local', 'remote', 'modelscope')) -class LocalModelScopeDataLoadStrategy(DataLoadStrategy): +class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'arxiv')) -class LocalArxivDataLoadStrategy(DataLoadStrategy): +class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'wiki')) -class LocalWikiDataLoadStrategy(DataLoadStrategy): +class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass @@ -115,7 +136,7 @@ def load_data(self): @DATALOAD_STRATEGY_REGISTRY.register_decorator( ('local', 'remote', 'commoncrawl')) -class LocalCommonCrawlDataLoadStrategy(DataLoadStrategy): +class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): def load_data(self): pass diff --git a/data_juicer/core/sandbox/factories.py b/data_juicer/core/sandbox/factories.py index a0b7fc8b7..c3aeeedc0 100644 --- a/data_juicer/core/sandbox/factories.py +++ b/data_juicer/core/sandbox/factories.py @@ -1,5 +1,5 @@ from data_juicer.core import Analyzer as DJAnalyzer -from data_juicer.core import Executor as DJExecutor +from data_juicer.core.executor import LocalExecutor as DJExecutor from data_juicer.core.sandbox.evaluators import (Gpt3QualityEvaluator, InceptionEvaluator, VBenchEvaluator) From cb5b80a4ba0768a5abebc60ea56e62a292d4cd79 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 9 Dec 2024 22:21:13 -0800 Subject: [PATCH 14/56] fix circular dependency; add dataset config test --- data_juicer/core/__init__.py | 3 +- data_juicer/core/data/__init__.py | 6 +- data_juicer/core/executor/factory.py | 4 +- data_juicer/core/executor/local_executor.py | 102 ++++++++++---------- data_juicer/ops/selector/random_selector.py | 9 +- data_juicer/utils/sample.py | 54 ----------- tests/core/data/test_config.yaml | 4 +- tests/core/test_dataset_builder.py | 24 ++++- 8 files changed, 87 insertions(+), 119 deletions(-) diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index edb24ad7b..f5450fabc 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,7 +1,7 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import ExecutorFactory, LocalExecutor, RayExecutor +from .executor import ExecutorBase, ExecutorFactory, LocalExecutor, RayExecutor from .exporter import Exporter from .monitor import Monitor from .tracer import Tracer @@ -13,6 +13,7 @@ 'ExecutorFactory', 'LocalExecutor', 'RayExecutor', + 'ExecutorBase', 'Exporter', 'Monitor', 'Tracer', diff --git a/data_juicer/core/data/__init__.py b/data_juicer/core/data/__init__.py index 0c8ec69ca..d93899665 100644 --- a/data_juicer/core/data/__init__.py +++ b/data_juicer/core/data/__init__.py @@ -1,4 +1,6 @@ -from .dj_dataset import DJDataset, NestedDataset +from .dj_dataset import DJDataset, NestedDataset, wrap_func_with_nested_access from .ray_dataset import RayDataset -__all__ = ['DJDataset', 'NestedDataset', 'RayDataset'] +__all__ = [ + 'DJDataset', 'NestedDataset', 'RayDataset', 'wrap_func_with_nested_access' +] diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py index 22944068e..3133e3b1a 100644 --- a/data_juicer/core/executor/factory.py +++ b/data_juicer/core/executor/factory.py @@ -1,7 +1,7 @@ from typing import Union -from local_executor import LocalExecutor -from ray_executor import RayExecutor +from .local_executor import LocalExecutor +from .ray_executor import RayExecutor class ExecutorFactory: diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index 5cb5f27d1..ac8aca353 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -13,15 +13,13 @@ from data_juicer.core.exporter import Exporter from data_juicer.core.tracer import Tracer from data_juicer.format.load import load_formatter -from data_juicer.format.mixture_formatter import MixtureFormatter from data_juicer.ops import OPERATORS, load_ops from data_juicer.ops.op_fusion import fuse_operators -from data_juicer.ops.selector.frequency_specified_field_selector import \ - FrequencySpecifiedFieldSelector -from data_juicer.ops.selector.topk_specified_field_selector import \ - TopkSpecifiedFieldSelector +from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector, + TopkSpecifiedFieldSelector) from data_juicer.utils import cache_utils from data_juicer.utils.ckpt_utils import CheckpointManager +from data_juicer.utils.sample import random_sample class LocalExecutor(ExecutorBase): @@ -97,52 +95,6 @@ def __init__(self, cfg: Optional[Namespace] = None): logger.info('Trace for all ops.') self.op_list_to_trace = set(OPERATORS.modules.keys()) - def sample_data(self, - dataset_to_sample: Dataset = None, - load_data_np=None, - sample_ratio: float = 1.0, - sample_algo: str = 'uniform', - **kwargs): - """ - Sample a subset from the given dataset. - - :param dataset_to_sample: Dataset to sample from. If None, will use - the formatter linked by the executor. Default is None. - :param load_data_np: number of workers when loading the dataset. - :param sample_ratio: The ratio of the sample size to the original - dataset size. Default is 1.0 (no sampling). - :param sample_algo: Sampling algorithm to use. Options are "uniform", - "frequency_specified_field_selector", or - "topk_specified_field_selector". - Default is "uniform". - :return: A sampled Dataset. - """ - # Determine the dataset to sample from - if dataset_to_sample is not None: - dataset = dataset_to_sample - elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: - logger.info('Loading dataset from checkpoint...') - dataset = self.ckpt_manager.load_ckpt() - elif hasattr(self, 'formatter'): - logger.info('Loading dataset from data formatter...') - if load_data_np is None: - load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) - else: - raise ValueError('No dataset available to sample from.') - - # Perform sampling based on the specified algorithm - if sample_algo == 'uniform': - return MixtureFormatter.random_sample(dataset, sample_ratio) - elif sample_algo == 'frequency_specified_field_selector': - dj_op = FrequencySpecifiedFieldSelector(**kwargs) - return dj_op.process(dataset) - elif sample_algo == 'topk_specified_field_selector': - dj_op = TopkSpecifiedFieldSelector(**kwargs) - return dj_op.process(dataset) - else: - raise ValueError(f'Unsupported sample_algo: {sample_algo}') - def run(self, load_data_np: Optional[PositiveInt] = None, skip_return=False): @@ -215,3 +167,51 @@ def run(self, if not skip_return: return dataset + + def sample_data(self, + dataset_to_sample: Dataset = None, + load_data_np=None, + sample_ratio: float = 1.0, + sample_algo: str = 'uniform', + **kwargs): + """ + Sample a subset from the given dataset. + TODO add support other than LocalExecutor + + :param executor: executor + :param dataset_to_sample: Dataset to sample from. If None, will use + the formatter linked by the executor. Default is None. + :param load_data_np: number of workers when loading the dataset. + :param sample_ratio: The ratio of the sample size to the original + dataset size. Default is 1.0 (no sampling). + :param sample_algo: Sampling algorithm to use. Options are "uniform", + "frequency_specified_field_selector", or + "topk_specified_field_selector". + Default is "uniform". + :return: A sampled Dataset. + """ + # Determine the dataset to sample from + if dataset_to_sample is not None: + dataset = dataset_to_sample + elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: + logger.info('Loading dataset from checkpoint...') + dataset = self.ckpt_manager.load_ckpt() + elif hasattr(self, 'formatter'): + logger.info('Loading dataset from data formatter...') + if load_data_np is None: + load_data_np = self.cfg.np + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + else: + raise ValueError('No dataset available to sample from.') + + # Perform sampling based on the specified algorithm + if sample_algo == 'uniform': + return random_sample(dataset, sample_ratio) + elif sample_algo == 'frequency_specified_field_selector': + dj_op = FrequencySpecifiedFieldSelector(**kwargs) + return dj_op.process(dataset) + elif sample_algo == 'topk_specified_field_selector': + dj_op = TopkSpecifiedFieldSelector(**kwargs) + return dj_op.process(dataset) + else: + raise ValueError(f'Unsupported sample_algo: {sample_algo}') diff --git a/data_juicer/ops/selector/random_selector.py b/data_juicer/ops/selector/random_selector.py index c3990ab19..f92d82b68 100644 --- a/data_juicer/ops/selector/random_selector.py +++ b/data_juicer/ops/selector/random_selector.py @@ -3,9 +3,8 @@ from pydantic import Field, PositiveInt from typing_extensions import Annotated -from data_juicer.format.mixture_formatter import MixtureFormatter - -from ..base_op import OPERATORS, Selector +from data_juicer.ops.base_op import OPERATORS, Selector +from data_juicer.utils.sample import random_sample @OPERATORS.register_module('random_selector') @@ -41,7 +40,6 @@ def process(self, dataset): if self.select_ratio is None and self.select_num is None: return dataset - select_num = 0 if not self.select_ratio: select_num = self.select_num else: @@ -49,5 +47,4 @@ def process(self, dataset): if self.select_num and self.select_num < select_num: select_num = self.select_num - return MixtureFormatter.random_sample(dataset, - sample_number=select_num) + return random_sample(dataset, sample_number=select_num) diff --git a/data_juicer/utils/sample.py b/data_juicer/utils/sample.py index 0164dbec0..17275c588 100644 --- a/data_juicer/utils/sample.py +++ b/data_juicer/utils/sample.py @@ -1,60 +1,6 @@ from itertools import chain, repeat import numpy as np -from datasets import Dataset -from loguru import logger - -from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector, - TopkSpecifiedFieldSelector) - - -class SamplingMixin: - - def sample_data(self, - dataset_to_sample: Dataset = None, - load_data_np=None, - sample_ratio: float = 1.0, - sample_algo: str = 'uniform', - **kwargs): - """ - Sample a subset from the given dataset. - - :param dataset_to_sample: Dataset to sample from. If None, will use - the formatter linked by the executor. Default is None. - :param load_data_np: number of workers when loading the dataset. - :param sample_ratio: The ratio of the sample size to the original - dataset size. Default is 1.0 (no sampling). - :param sample_algo: Sampling algorithm to use. Options are "uniform", - "frequency_specified_field_selector", or - "topk_specified_field_selector". - Default is "uniform". - :return: A sampled Dataset. - """ - # Determine the dataset to sample from - if dataset_to_sample is not None: - dataset = dataset_to_sample - elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: - logger.info('Loading dataset from checkpoint...') - dataset = self.ckpt_manager.load_ckpt() - elif hasattr(self, 'formatter'): - logger.info('Loading dataset from data formatter...') - if load_data_np is None: - load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) - else: - raise ValueError('No dataset available to sample from.') - - # Perform sampling based on the specified algorithm - if sample_algo == 'uniform': - return random_sample(dataset, sample_ratio) - elif sample_algo == 'frequency_specified_field_selector': - dj_op = FrequencySpecifiedFieldSelector(**kwargs) - return dj_op.process(dataset) - elif sample_algo == 'topk_specified_field_selector': - dj_op = TopkSpecifiedFieldSelector(**kwargs) - return dj_op.process(dataset) - else: - raise ValueError(f'Unsupported sample_algo: {sample_algo}') def random_sample(dataset, weight=1.0, sample_number=0, seed=None): diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml index 642ecd958..9620bed65 100644 --- a/tests/core/data/test_config.yaml +++ b/tests/core/data/test_config.yaml @@ -1,5 +1,5 @@ -project_name: 'dataset-local-json' +project_name: 'dataset-ondisk-json' dataset: - type: 'local' + type: 'ondisk' path: - 'sample.json' \ No newline at end of file diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index ad55ec867..63b0b8343 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -1,4 +1,12 @@ import os +import unittest +from argparse import Namespace +from contextlib import redirect_stdout +from io import StringIO + +from networkx.classes import is_empty + +from data_juicer.config import init_configs from data_juicer.core.data.dataset_builder import rewrite_cli_datapath from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS @@ -42,4 +50,18 @@ def test_rewrite_cli_datapath_with_weights(self): self.assertEqual( [{'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5}, {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], - ans) \ No newline at end of file + ans) + + def test_dataset_builder_ondisk_config(self): + test_config_file = './data/test_config.yaml' + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_config_file}'.split()) + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'dataset-ondisk-json') + self.assertEqual(cfg.dataset, {'path': ['sample.json'], 'type': 'ondisk'}) + self.assertEqual(not cfg.dataset_path, True) + + +if __name__ == '__main__': + unittest.main() From daf7a85c7783fa69e0ba80a75290941c2243e31b Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Tue, 10 Dec 2024 11:13:24 -0800 Subject: [PATCH 15/56] update dataset_path parsing in config --- data_juicer/config/config.py | 13 ++++++++----- tests/core/data/test_config_list.yaml | 8 ++++++++ tests/core/test_dataset_builder.py | 15 ++++++++++++++- 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 tests/core/data/test_config_list.yaml diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 990fd0e45..18571e34a 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -415,19 +415,22 @@ def init_setup_from_cfg(cfg: Namespace): # check and get dataset dir if cfg.get('dataset_path', None) and os.path.exists(cfg.dataset_path): + logger.warning('dataset_path config is set and a valid local path') cfg.dataset_path = os.path.abspath(cfg.dataset_path) if os.path.isdir(cfg.dataset_path): cfg.dataset_dir = cfg.dataset_path else: cfg.dataset_dir = os.path.dirname(cfg.dataset_path) - elif cfg.dataset_path == '': - logger.warning('dataset_path is empty by default.') + elif cfg.dataset_path == '' and cfg.get('dataset', None): + logger.warning('dataset_path config is empty; dataset is present') cfg.dataset_dir = '' else: logger.warning(f'dataset_path [{cfg.dataset_path}] is not a valid ' - f'local path. Please check and retry, otherwise we ' - f'will treat it as a remote dataset or a mixture of ' - f'several datasets.') + f'local path, AND dataset is not present. ' + f'Please check and retry, otherwise we ' + f'will treat dataset_path as a remote dataset or a ' + f'mixture of several datasets.') + cfg.dataset_dir = '' # check number of processes np diff --git a/tests/core/data/test_config_list.yaml b/tests/core/data/test_config_list.yaml new file mode 100644 index 000000000..61ad61162 --- /dev/null +++ b/tests/core/data/test_config_list.yaml @@ -0,0 +1,8 @@ +project_name: 'dataset-ondisk-list' +dataset: + - type: 'ondisk' + path: + - 'sample.json' + - type: 'ondisk' + path: + - 'sample.txt' \ No newline at end of file diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 63b0b8343..32bd04e2f 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -59,7 +59,20 @@ def test_dataset_builder_ondisk_config(self): cfg = init_configs(args=f'--config {test_config_file}'.split()) self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'dataset-ondisk-json') - self.assertEqual(cfg.dataset, {'path': ['sample.json'], 'type': 'ondisk'}) + self.assertEqual(cfg.dataset, + {'path': ['sample.json'], 'type': 'ondisk'}) + self.assertEqual(not cfg.dataset_path, True) + + def test_dataset_builder_ondisk_config_list(self): + test_config_file = './data/test_config_list.yaml' + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_config_file}'.split()) + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'dataset-ondisk-list') + self.assertEqual(cfg.dataset,[ + {'path': ['sample.json'], 'type': 'ondisk'}, + {'path': ['sample.txt'], 'type': 'ondisk'}]) self.assertEqual(not cfg.dataset_path, True) From 7c48892fd70de1fba00ac499eae1a51616dc3306 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 11:03:43 -0800 Subject: [PATCH 16/56] fix download test case; add wildcard matching for load strategy --- data_juicer/core/data/dataset_builder.py | 25 ++- data_juicer/core/data/load_strategy.py | 196 +++++++++++++++++------ data_juicer/download/downloader.py | 2 +- tests/download/test_download.py | 23 ++- 4 files changed, 178 insertions(+), 68 deletions(-) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 1dd5a830c..ea7841d34 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -2,6 +2,7 @@ from typing import List, Tuple, Union from data_juicer.core.data import NestedDataset +from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset from data_juicer.utils.file_utils import is_absolute_path @@ -9,6 +10,7 @@ class DatasetBuilder(object): def __init__(self, cfg): + # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) elif cfg.dataset is not None: @@ -17,17 +19,26 @@ def __init__(self, cfg): raise ValueError( 'Unable to initialize dataset; should have one of ' 'dataset_path or dataset in configurations') - for config in ds_configs: - # initialize data loader from strategy - pass + # dataset config could be a list or a single entry; retrofit + if not isinstance(ds_configs, list): + ds_configs = [ds_configs] + self.load_strategies = [] + for ds_config in ds_configs: + # initialize data loading strategy + executor_type = cfg.get('executor_type', None) + data_type = ds_config.get('type', None) + data_source = ds_config.get('source', None) + self.load_strategies.append( + DataLoadStrategyRegistry.get_strategy_class( + executor_type, data_type, data_source)(ds_config)) def load_dataset(self) -> Union[NestedDataset, RayDataset]: # handle mixture dataset, nested dataset # handle sampling of mixture datasets - # - for f in self.formatters: - f.load_dataset() - return None + _datasets = [] + for f in self.load_strategies: + _datasets.append(f.load_data()) + return _datasets[0] def rewrite_cli_datapath(dataset_path) -> List: diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 3d44bf34d..8fd3ddb74 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -1,5 +1,7 @@ +import fnmatch from abc import ABC, abstractmethod -from typing import Union +from dataclasses import dataclass +from typing import Dict, Optional, Type, Union from data_juicer.core.data import DJDataset, RayDataset @@ -8,53 +10,124 @@ # DJDataset, RayDataset, DaskDataset, etc -class DataLoadStrategyRegistry: - - def __init__(self): - self._registry = {} - - def register(self, key: tuple, strategy): - """Register a strategy for a specific tuple key.""" - if key in self._registry: - raise ValueError(f'Strategy for key {key} is already registered.') - self._registry[key] = strategy +@dataclass(frozen=True) +class StrategyKey: + """ + Immutable key for strategy registration with wildcard support + """ + executor_type: str + data_type: str + data_source: str - def get_strategy(self, key: tuple): - """Retrieve the strategy for a specific tuple key.""" - if key not in self._registry: - raise ValueError(f'No strategy registered for key {key}.') - return self._registry[key] + def matches(self, other: 'StrategyKey') -> bool: + """ + Check if this key matches another key with wildcard support - def register_decorator(self, key: tuple): - """Decorator for registering a strategy with a specific tuple key.""" + Supports Unix-style wildcards: + - '*' matches any string + - '?' matches any single character + - '[seq]' matches any character in seq + - '[!seq]' matches any character not in seq + """ + return (fnmatch.fnmatch(other.executor_type, self.executor_type) + and fnmatch.fnmatch(other.data_type, self.data_type) + and fnmatch.fnmatch(other.data_source, self.data_source)) - def decorator(func): - self.register(key, func) - return func # Return the original function - return decorator +class DataLoadStrategy(ABC): + """ + abstract class for data load strategy + """ + def __init__(self, ds_config: Dict): + self.ds_config = ds_config -DATALOAD_STRATEGY_REGISTRY = DataLoadStrategyRegistry() + @abstractmethod + def load_data(self) -> Union[DJDataset, RayDataset]: + pass -class DataLoadStrategyFactory: +class DataLoadStrategyRegistry: + """ + Flexible strategy registry with wildcard matching + """ + _strategies: Dict[StrategyKey, Type[DataLoadStrategy]] = {} @classmethod - def create_dataload_strategy(cls, executor_type, dataset_type, - dataset_source): - DATALOAD_STRATEGY_REGISTRY.get_strategy( - (executor_type, dataset_type, dataset_source)) + def get_strategy_class( + cls, executor_type: str, data_type: str, + data_source: str) -> Optional[Type[DataLoadStrategy]]: + """ + Retrieve the most specific matching strategy + + Matching priority: + 1. Exact match + 2. Wildcard matches from most specific to most general + """ + # Create the lookup key + lookup_key = StrategyKey(executor_type, data_type, data_source) + + # First, check for exact match + exact_match = cls._strategies.get(lookup_key) + if exact_match: + return exact_match + + # Find all matching wildcard strategies + matching_strategies = [] + for registered_key, strategy in cls._strategies.items(): + if registered_key.matches(lookup_key): + matching_strategies.append((registered_key, strategy)) + + # Sort matching strategies by specificity (fewer wildcards first) + if matching_strategies: + + def specificity_score(key: StrategyKey) -> int: + """ + Calculate specificity score (lower is more specific) + Exact match: 0 + One wildcard: 1 + Two wildcards: 2 + All wildcards: 3 + """ + return sum(1 for part in + [key.executor_type, key.data_type, key.data_source] + if part == '*') + + matching_strategies.sort(key=lambda x: specificity_score(x[0])) + return matching_strategies[0][1] + + # No matching strategy found + return None + @classmethod + def register(cls, executor_type: str, data_type: str, data_source: str): + """ + Decorator for registering data load strategies with wildcard support + + :param executor_type: Type of executor (e.g., 'local', 'ray') + :param data_type: Type of data (e.g., 'ondisk', 'remote') + :param data_source: Specific data source (e.g., 'arxiv', 's3') + :return: Decorator function + """ + + def decorator(strategy_class: Type[DataLoadStrategy]): + """ + Register the strategy class for the given key + + :param strategy_class: Strategy class to register + :return: Original strategy class + """ + key = StrategyKey(executor_type, data_type, data_source) + cls._strategies[key] = strategy_class + return strategy_class -class DataLoadStrategy(ABC): - - @abstractmethod - def load_data(self) -> Union[DJDataset, RayDataset]: - pass + return decorator class RayDataLoadStrategy(DataLoadStrategy): + """ + abstract class for data load strategy for RayExecutor + """ @abstractmethod def load_data(self) -> RayDataset: @@ -62,6 +135,9 @@ def load_data(self) -> RayDataset: class LocalDataLoadStrategy(DataLoadStrategy): + """ + abstract class for data load strategy for LocalExecutor + """ @abstractmethod def load_data(self) -> DJDataset: @@ -74,69 +150,83 @@ def load_data(self) -> DJDataset: # def load_data(self) -> Union[DaskDataset]: # pass +# TODO nemo support +# class NemoDataLoadStrategy(DataLoadStrategy): +# @abstractmethod +# def load_data(self) -> Union[NemoDataset]: +# pass + -@DATALOAD_STRATEGY_REGISTRY.register_decorator(('ray', 'ondisk', 'json')) +@DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator( - ('ray', 'remote', 'huggingface')) +@DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface') class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Json')) -class LocalOndiskJsonDataLoadStrategy(LocalDataLoadStrategy): - - def load_data(self): - pass - - -@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'ondisk', 'Parquet')) -class LocalOndiskParquetDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('local', 'ondisk', '*') +class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for on disk data for LocalExecutor + rely on AutoFormatter for actual data loading + """ def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator( - ('local', 'remote', 'huggingface')) +@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface') class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for Huggingface dataset for LocalExecutor + """ def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator( - ('local', 'remote', 'modelscope')) +@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for ModelScope dataset for LocalExecutor + """ def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'arxiv')) +@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv') class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for arxiv dataset for LocalExecutor + """ def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator(('local', 'remote', 'wiki')) +@DataLoadStrategyRegistry.register('local', 'remote', 'wiki') class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for wiki dataset for LocalExecutor + """ def load_data(self): pass -@DATALOAD_STRATEGY_REGISTRY.register_decorator( - ('local', 'remote', 'commoncrawl')) +@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl') class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): + """ + data load strategy for commoncrawl dataset for LocalExecutor + """ def load_data(self): pass diff --git a/data_juicer/download/downloader.py b/data_juicer/download/downloader.py index a5b1db3ae..8cebfec54 100644 --- a/data_juicer/download/downloader.py +++ b/data_juicer/download/downloader.py @@ -75,7 +75,7 @@ def _download_and_extract_single_partition(paths: Tuple[str, str], item_count = 0 for item in iterator.iterate(downloaded_file): item_count += 1 - if item_limit and item_count >= item_limit: + if item_limit and item_count > item_limit: break record_meta, content = item # Extract the text from the record diff --git a/tests/download/test_download.py b/tests/download/test_download.py index 907c93ded..d07e46e5b 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -1,10 +1,20 @@ import unittest import tempfile +import os +import shutil from data_juicer.download.wikipedia import ( get_wikipedia_urls, download_wikipedia ) -class TestDownload: +class TestDownload(unittest.TestCase): + def setUp(self): + # Creates a temporary directory that persists until you delete it + self.temp_dir = tempfile.mkdtemp(prefix='dj_test_') + + def tearDown(self): + # Clean up the temporary directory after each test + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) def test_wikipedia_urls(self): dump_date = "20241101" @@ -16,12 +26,11 @@ def test_wikipedia_urls(self): def test_wikipedia_download(self): dump_date = "20241101" - output_directory = tempfile.gettempdir() + "/dj_temp/" - url_limit = 5 - item_limit = 10 - wiki_df = download_wikipedia(output_directory, dump_date=dump_date, url_limit=url_limit, item_limit=item_limit) - sample = wiki_df.take(50) - assert len(sample) == 50 + url_limit = 1 + item_limit = 50 + wiki_df = download_wikipedia(self.temp_dir, dump_date=dump_date, url_limit=url_limit, item_limit=item_limit) + sample = wiki_df.take(10) + assert len(sample) == 10 if __name__ == '__main__': From 940b44d07364adc963070abf85a6643e00beffb2 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 11:10:56 -0800 Subject: [PATCH 17/56] add test case for load strategy wild card matching --- tests/core/test_dataload_strategy.py | 116 +++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/core/test_dataload_strategy.py diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py new file mode 100644 index 000000000..b3f489cfb --- /dev/null +++ b/tests/core/test_dataload_strategy.py @@ -0,0 +1,116 @@ +import unittest +from data_juicer.core.data.load_strategy import ( + DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey +) + +class MockStrategy(DataLoadStrategy): + def load_data(self): + pass + +class TestDataLoadStrategyRegistry(unittest.TestCase): + def setUp(self): + # Clear existing strategies before each test + DataLoadStrategyRegistry._strategies = {} + + def test_exact_match(self): + # Register a specific strategy + @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + class TestStrategy(MockStrategy): + pass + + # Test exact match + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, TestStrategy) + + # Test no match + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + self.assertIsNone(strategy) + + def test_wildcard_matching(self): + # Register strategies with different wildcard patterns + @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + class AllFilesStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', '*', '*') + class AllLocalStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('*', '*', '*') + class FallbackStrategy(MockStrategy): + pass + + # Test specific matches + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard + + def test_specificity_priority(self): + @DataLoadStrategyRegistry.register('*', '*', '*') + class GeneralStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', '*', '*') + class LocalStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + class LocalOndiskStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + class ExactStrategy(MockStrategy): + pass + + # Test matching priority + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, ExactStrategy) # Should match exact first + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + self.assertEqual(strategy, LocalStrategy) # Should match two wildcards + + strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard + + def test_pattern_matching(self): + @DataLoadStrategyRegistry.register('local', 'ondisk', '*.json') + class JsonStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', 'data_[0-9]*') + class NumberedDataStrategy(MockStrategy): + pass + + # Test pattern matching + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.json') + self.assertEqual(strategy, JsonStrategy) + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'data_123') + self.assertEqual(strategy, NumberedDataStrategy) + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.csv') + self.assertIsNone(strategy) + + def test_strategy_key_matches(self): + # Test StrategyKey matching directly + wildcard_key = StrategyKey('*', 'ondisk', '*.json') + specific_key = StrategyKey('local', 'ondisk', 'test.json') + + self.assertTrue(wildcard_key.matches(specific_key)) + self.assertFalse(specific_key.matches(wildcard_key)) # Exact keys don't match wildcards + + # Test pattern matching + pattern_key = StrategyKey('local', '*', 'data_[0-9]*') + match_key = StrategyKey('local', 'ondisk', 'data_123') + no_match_key = StrategyKey('local', 'ondisk', 'data_abc') + + self.assertTrue(pattern_key.matches(match_key)) + self.assertFalse(pattern_key.matches(no_match_key)) From b80f9913916dd83d9b42b0b5ded6f187d6b95c5f Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 13:02:18 -0800 Subject: [PATCH 18/56] add more test cases for datapath rewrite logic; fix rewrite to handle space in file name --- data_juicer/core/data/dataset_builder.py | 18 ++- tests/core/test_dataset_builder.py | 163 +++++++++++++++++++++-- 2 files changed, 165 insertions(+), 16 deletions(-) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index ea7841d34..553761516 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -1,4 +1,5 @@ import os +import shlex from typing import List, Tuple, Union from data_juicer.core.data import NestedDataset @@ -78,16 +79,25 @@ def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]: them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` :return: list of dataset path and list of weights """ - data_prefix = dataset_path.split() + # Handle empty input + if not dataset_path or not dataset_path.strip(): + return [], [] + + # Use shlex to properly handle quoted strings + try: + tokens = shlex.split(dataset_path) + except ValueError as e: + raise ValueError(f'Invalid dataset path format: {e}') + prefixes = [] weights = [] - for i in range(len(data_prefix)): + for i in range(len(tokens)): try: - value = max(float(data_prefix[i]), 0.0) + value = max(float(tokens[i]), 0.0) weights.append(value) except: # noqa: E722 - value = data_prefix[i].strip() + value = tokens[i].strip() # if not set weight, use 1.0 as default if i == 0 or len(weights) == len(prefixes): weights.append(1.0) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 32bd04e2f..86e0aab0e 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -1,35 +1,32 @@ import os import unittest +from unittest.mock import patch from argparse import Namespace from contextlib import redirect_stdout from io import StringIO - -from networkx.classes import is_empty - from data_juicer.config import init_configs -from data_juicer.core.data.dataset_builder import rewrite_cli_datapath +from data_juicer.core.data.dataset_builder import rewrite_cli_datapath, parse_cli_datapath from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS @SKIPPED_TESTS.register_module() class DatasetBuilderTest(DataJuicerTestCaseBase): + def setUp(self): + # Get the directory where this test file is located + test_file_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(test_file_dir) + def test_rewrite_cli_datapath_local_single_file(self): dataset_path = "./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans) + [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans) def test_rewrite_cli_datapath_local_directory(self): dataset_path = "./data" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data'], 'type': 'ondisk', 'weight': 1.0}], ans) - - def test_rewrite_cli_datapath_absolute_path(self): - dataset_path = os.curdir + "/data/sample.txt" - ans = rewrite_cli_datapath(dataset_path) - self.assertEqual( - [{'type': 'ondisk', 'path': [dataset_path], 'weight': 1.0}], ans) + [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans) def test_rewrite_cli_datapath_hf(self): dataset_path = "hf-internal-testing/librispeech_asr_dummy" @@ -75,6 +72,148 @@ def test_dataset_builder_ondisk_config_list(self): {'path': ['sample.txt'], 'type': 'ondisk'}]) self.assertEqual(not cfg.dataset_path, True) + @patch('os.path.isdir') + @patch('os.path.isfile') + def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir): + # Mock os.path.isdir and os.path.isfile to simulate local files + mock_isfile.side_effect = lambda x: x.endswith('.jsonl') + mock_isdir.side_effect = lambda x: x.endswith('_dir') + + dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" + expected = [ + {'type': 'ondisk', 'path': ['ds1.jsonl'], 'weight': 1.0}, + {'type': 'ondisk', 'path': ['ds2_dir'], 'weight': 2.0}, + {'type': 'ondisk', 'path': ['ds3.jsonl'], 'weight': 3.0} + ] + result = rewrite_cli_datapath(dataset_path) + self.assertEqual(result, expected) + + def test_rewrite_cli_datapath_huggingface(self): + dataset_path = "1.0 huggingface/dataset" + expected = [ + {'type': 'huggingface', 'path': 'huggingface/dataset', 'split': 'train'} + ] + result = rewrite_cli_datapath(dataset_path) + self.assertEqual(result, expected) + + def test_rewrite_cli_datapath_invalid(self): + dataset_path = "1.0 ./invalid_path" + with self.assertRaises(ValueError): + rewrite_cli_datapath(dataset_path) + + def test_parse_cli_datapath(self): + dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" + expected_paths = ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl'] + expected_weights = [1.0, 2.0, 3.0] + paths, weights = parse_cli_datapath(dataset_path) + self.assertEqual(paths, expected_paths) + self.assertEqual(weights, expected_weights) + + def test_parse_cli_datapath_default_weight(self): + dataset_path = "ds1.jsonl ds2_dir 2.0 ds3.jsonl" + expected_paths = ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl'] + expected_weights = [1.0, 1.0, 2.0] + paths, weights = parse_cli_datapath(dataset_path) + self.assertEqual(paths, expected_paths) + self.assertEqual(weights, expected_weights) + + + def test_parse_cli_datapath_edge_cases(self): + # Test various edge cases + test_cases = [ + # Empty string + ("", [], []), + # Single path + ("file.txt", ['file.txt'], [1.0]), + # Multiple spaces between items + ("file1.txt file2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]), + # Tab characters + ("file1.txt\tfile2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]), + # Paths with spaces in them (quoted) + ('"my file.txt" 1.5 "other file.txt"', + ['my file.txt', 'other file.txt'], + [1.0, 1.5]), + ] + + for input_path, expected_paths, expected_weights in test_cases: + paths, weights = parse_cli_datapath(input_path) + self.assertEqual(paths, expected_paths, + f"Failed paths for input: {input_path}") + self.assertEqual(weights, expected_weights, + f"Failed weights for input: {input_path}") + + def test_parse_cli_datapath_valid_weights(self): + # Test various valid weight formats + test_cases = [ + ("1.0 file.txt", ['file.txt'], [1.0]), + ("1.5 file1.txt 2.0 file2.txt", + ['file1.txt', 'file2.txt'], + [1.5, 2.0]), + ("0.5 file1.txt file2.txt 1.5 file3.txt", + ['file1.txt', 'file2.txt', 'file3.txt'], + [0.5, 1.0, 1.5]), + # Test integer weights + ("1 file.txt", ['file.txt'], [1.0]), + ("2 file1.txt 3 file2.txt", + ['file1.txt', 'file2.txt'], + [2.0, 3.0]), + ] + + for input_path, expected_paths, expected_weights in test_cases: + paths, weights = parse_cli_datapath(input_path) + self.assertEqual(paths, expected_paths, + f"Failed paths for input: {input_path}") + self.assertEqual(weights, expected_weights, + f"Failed weights for input: {input_path}") + + def test_parse_cli_datapath_special_characters(self): + # Test paths with special characters + test_cases = [ + # Paths with hyphens and underscores + ("my-file_1.txt", ['my-file_1.txt'], [1.0]), + # Paths with dots + ("path/to/file.with.dots.txt", ['path/to/file.with.dots.txt'], [1.0]), + # Paths with special characters + ("file#1.txt", ['file#1.txt'], [1.0]), + # Mixed case with weight + ("1.0 Path/To/File.TXT", ['Path/To/File.TXT'], [1.0]), + # Multiple paths with special characters + ("2.0 file#1.txt 3.0 path/to/file-2.txt", + ['file#1.txt', 'path/to/file-2.txt'], + [2.0, 3.0]), + ] + + for input_path, expected_paths, expected_weights in test_cases: + paths, weights = parse_cli_datapath(input_path) + self.assertEqual(paths, expected_paths, + f"Failed paths for input: {input_path}") + self.assertEqual(weights, expected_weights, + f"Failed weights for input: {input_path}") + + def test_parse_cli_datapath_multiple_datasets(self): + # Test multiple datasets with various weight combinations + test_cases = [ + # Multiple datasets with all weights specified + ("0.5 data1.txt 1.5 data2.txt 2.0 data3.txt", + ['data1.txt', 'data2.txt', 'data3.txt'], + [0.5, 1.5, 2.0]), + # Mix of weighted and unweighted datasets + ("data1.txt 1.5 data2.txt data3.txt", + ['data1.txt', 'data2.txt', 'data3.txt'], + [1.0, 1.5, 1.0]), + # Multiple datasets with same weight + ("2.0 data1.txt 2.0 data2.txt 2.0 data3.txt", + ['data1.txt', 'data2.txt', 'data3.txt'], + [2.0, 2.0, 2.0]), + ] + + for input_path, expected_paths, expected_weights in test_cases: + paths, weights = parse_cli_datapath(input_path) + self.assertEqual(paths, expected_paths, + f"Failed paths for input: {input_path}") + self.assertEqual(weights, expected_weights, + f"Failed weights for input: {input_path}") + if __name__ == '__main__': unittest.main() From 0d5d4ba0c74a269a1226c16eb3da1b508a546a71 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 13:14:58 -0800 Subject: [PATCH 19/56] materialize symlinks for duplicates --- tests/ops/data/img1_dup.png | 1 + tests/ops/data/img2_dup.jpg | 1 + tests/ops/data/img3_dup.jpg | 1 + tests/ops/data/img3_dup_dup.jpg | 1 + tests/ops/data/video1_dup.mp4 | 1 + tests/ops/data/video2_dup.mp4 | 1 + tests/ops/data/video3_dup.mp4 | 1 + tests/ops/data/video3_dup_dup.mp4 | 1 + 8 files changed, 8 insertions(+) create mode 120000 tests/ops/data/img1_dup.png create mode 120000 tests/ops/data/img2_dup.jpg create mode 120000 tests/ops/data/img3_dup.jpg create mode 120000 tests/ops/data/img3_dup_dup.jpg create mode 120000 tests/ops/data/video1_dup.mp4 create mode 120000 tests/ops/data/video2_dup.mp4 create mode 120000 tests/ops/data/video3_dup.mp4 create mode 120000 tests/ops/data/video3_dup_dup.mp4 diff --git a/tests/ops/data/img1_dup.png b/tests/ops/data/img1_dup.png new file mode 120000 index 000000000..d62a85900 --- /dev/null +++ b/tests/ops/data/img1_dup.png @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img1.png \ No newline at end of file diff --git a/tests/ops/data/img2_dup.jpg b/tests/ops/data/img2_dup.jpg new file mode 120000 index 000000000..8a99a2526 --- /dev/null +++ b/tests/ops/data/img2_dup.jpg @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img2.jpg \ No newline at end of file diff --git a/tests/ops/data/img3_dup.jpg b/tests/ops/data/img3_dup.jpg new file mode 120000 index 000000000..6e8c435e3 --- /dev/null +++ b/tests/ops/data/img3_dup.jpg @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img3.jpg \ No newline at end of file diff --git a/tests/ops/data/img3_dup_dup.jpg b/tests/ops/data/img3_dup_dup.jpg new file mode 120000 index 000000000..f539c0972 --- /dev/null +++ b/tests/ops/data/img3_dup_dup.jpg @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img3_dup.jpg \ No newline at end of file diff --git a/tests/ops/data/video1_dup.mp4 b/tests/ops/data/video1_dup.mp4 new file mode 120000 index 000000000..6d1bbbc84 --- /dev/null +++ b/tests/ops/data/video1_dup.mp4 @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video1.mp4 \ No newline at end of file diff --git a/tests/ops/data/video2_dup.mp4 b/tests/ops/data/video2_dup.mp4 new file mode 120000 index 000000000..8fa6335be --- /dev/null +++ b/tests/ops/data/video2_dup.mp4 @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video2.mp4 \ No newline at end of file diff --git a/tests/ops/data/video3_dup.mp4 b/tests/ops/data/video3_dup.mp4 new file mode 120000 index 000000000..f63158860 --- /dev/null +++ b/tests/ops/data/video3_dup.mp4 @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video3.mp4 \ No newline at end of file diff --git a/tests/ops/data/video3_dup_dup.mp4 b/tests/ops/data/video3_dup_dup.mp4 new file mode 120000 index 000000000..6a225ba39 --- /dev/null +++ b/tests/ops/data/video3_dup_dup.mp4 @@ -0,0 +1 @@ +/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video3_dup.mp4 \ No newline at end of file From f3a4ec4c7d724040f4d9cc35101f61e518d160be Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 17:26:24 -0800 Subject: [PATCH 20/56] add load strategy validation framework --- data_juicer/core/data/dataset_builder.py | 5 +- data_juicer/core/data/load_strategy.py | 112 +++++++++++++++++--- data_juicer/core/executor/local_executor.py | 4 +- data_juicer/format/load.py | 66 +++++++++++- data_juicer/format/mixture_formatter.py | 104 ++---------------- 5 files changed, 178 insertions(+), 113 deletions(-) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 553761516..06e4dbbf0 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -11,6 +11,7 @@ class DatasetBuilder(object): def __init__(self, cfg): + self.cfg = cfg # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) @@ -26,7 +27,7 @@ def __init__(self, cfg): self.load_strategies = [] for ds_config in ds_configs: # initialize data loading strategy - executor_type = cfg.get('executor_type', None) + executor_type = ds_config.get('executor_type', None) data_type = ds_config.get('type', None) data_source = ds_config.get('source', None) self.load_strategies.append( @@ -38,7 +39,7 @@ def load_dataset(self) -> Union[NestedDataset, RayDataset]: # handle sampling of mixture datasets _datasets = [] for f in self.load_strategies: - _datasets.append(f.load_data()) + _datasets.append(f.load_data(self.cfg)) return _datasets[0] diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 8fd3ddb74..20d15194d 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -1,9 +1,14 @@ import fnmatch from abc import ABC, abstractmethod +from argparse import Namespace from dataclasses import dataclass from typing import Dict, Optional, Type, Union from data_juicer.core.data import DJDataset, RayDataset +from data_juicer.utils.lazy_loader import LazyLoader + +ray = LazyLoader('ray', 'ray') +rd = LazyLoader('rd', 'ray.data') # based on executor type and data source type, use different # data load strategy to product corresponding datasets @@ -34,16 +39,72 @@ def matches(self, other: 'StrategyKey') -> bool: and fnmatch.fnmatch(other.data_source, self.data_source)) -class DataLoadStrategy(ABC): +class ValidationError(Exception): + """Custom exception for validation errors""" + pass + + +class ConfigValidator: + """Mixin class for configuration validation""" + + # Define validation rules for each strategy type + VALIDATION_RULES = { + 'required_fields': [], # Fields that must be present + 'field_types': {}, # Expected types for fields + 'custom_validators': {} # Custom validation functions + } + + def validate_config(self, ds_config: Dict) -> None: + """ + Validate the configuration dictionary. + + Args: + ds_config: Configuration dictionary to validate + + Raises: + ValidationError: If validation fails + """ + # Check required fields + missing_fields = [ + field for field in self.VALIDATION_RULES['required_fields'] + if field not in ds_config + ] + if missing_fields: + raise ValidationError( + f"Missing required fields: {', '.join(missing_fields)}") + + # Check field types + for field, expected_type in self.VALIDATION_RULES['field_types'].items( + ): + if field in ds_config: + value = ds_config[field] + if not isinstance(value, expected_type): + raise ValidationError(f"Field '{field}' must be of " + "type '{expected_type.__name__}', " + f"got '{type(value).__name__}'") + + # Run custom validators + for field, validator in self.VALIDATION_RULES[ + 'custom_validators'].items(): + if field in ds_config: + try: + validator(ds_config[field]) + except Exception as e: + raise ValidationError( + f"Validation failed for field '{field}': {str(e)}") + + +class DataLoadStrategy(ABC, ConfigValidator): """ abstract class for data load strategy """ def __init__(self, ds_config: Dict): + self.validate_config(ds_config) self.ds_config = ds_config @abstractmethod - def load_data(self) -> Union[DJDataset, RayDataset]: + def load_data(self, cfg: Namespace) -> Union[DJDataset, RayDataset]: pass @@ -140,7 +201,7 @@ class LocalDataLoadStrategy(DataLoadStrategy): """ @abstractmethod - def load_data(self) -> DJDataset: + def load_data(self, cfg: Namespace) -> DJDataset: pass @@ -160,14 +221,41 @@ def load_data(self) -> DJDataset: @DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): - def load_data(self): - pass + VALIDATION_RULES = { + 'required_fields': ['dataset_path'], + 'field_types': { + 'dataset_path': (str, list), # Can be string or list + 'cache_dir': str, + 'shuffle': bool + }, + 'custom_validators': { + 'dataset_path': + lambda x: + (isinstance(x, (str, list)) and + (isinstance(x, str) or all(isinstance(p, str) for p in x))) + } + } + + def load_data(self, cfg: Namespace): + return rd.read_json(self.ds_config.path) @DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface') class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): - def load_data(self): + VALIDATION_RULES = { + 'required_fields': ['dataset_name', 'split'], + 'field_types': { + 'dataset_name': str, + 'split': str, + 'streaming': bool + }, + 'custom_validators': { + 'split': lambda x: x in ['train', 'test', 'validation'] + } + } + + def load_data(self, cfg: Namespace): pass @@ -178,7 +266,7 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): rely on AutoFormatter for actual data loading """ - def load_data(self): + def load_data(self, cfg: Namespace): pass @@ -188,7 +276,7 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): data load strategy for Huggingface dataset for LocalExecutor """ - def load_data(self): + def load_data(self, cfg: Namespace): pass @@ -198,7 +286,7 @@ class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): data load strategy for ModelScope dataset for LocalExecutor """ - def load_data(self): + def load_data(self, cfg: Namespace): pass @@ -208,7 +296,7 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): data load strategy for arxiv dataset for LocalExecutor """ - def load_data(self): + def load_data(self, cfg: Namespace): pass @@ -218,7 +306,7 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): data load strategy for wiki dataset for LocalExecutor """ - def load_data(self): + def load_data(self, cfg: Namespace): pass @@ -228,5 +316,5 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): data load strategy for commoncrawl dataset for LocalExecutor """ - def load_data(self): + def load_data(self, cfg: Namespace): pass diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index ac8aca353..39c9f595e 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -110,10 +110,10 @@ def run(self, logger.info('Loading dataset from checkpoint...') dataset = self.ckpt_manager.load_ckpt() else: - logger.info('Loading dataset from data formatter...') + logger.info('Loading dataset from dataset builder...') if load_data_np is None: load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) + dataset = self.dataset_builder.load_dataset() # 2. extract processes and optimize their orders logger.info('Preparing process operators...') diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py index 3a65817be..ef58b124b 100644 --- a/data_juicer/format/load.py +++ b/data_juicer/format/load.py @@ -1,5 +1,9 @@ -from .formatter import BaseFormatter -from .mixture_formatter import MixtureFormatter +import os + +from data_juicer.format.formatter import (FORMATTERS, BaseFormatter, + MixtureFormatter, RemoteFormatter) +from data_juicer.utils.file_utils import (find_files_with_suffix, + is_absolute_path) def load_formatter(dataset_path, @@ -39,3 +43,61 @@ def load_formatter(dataset_path, add_suffix=add_suffix, **kwargs) return formatter + + +def _load_formatter(dataset_path, + text_keys=None, + suffixes=None, + add_suffix=False, + **kwargs) -> BaseFormatter: + """ + Load the appropriate formatter for different types of data formats. + + :param dataset_path: Path to dataset file or dataset directory + :param text_keys: key names of field that stores sample text. + Default: None + :param suffixes: the suffix of files that will be read. Default: + None + :return: a dataset formatter. + """ + + if suffixes is None: + suffixes = [] + ext_num = {} + if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): + file_dict = find_files_with_suffix(dataset_path, suffixes) + if not file_dict: + raise IOError( + 'Unable to find files matching the suffix from {}'.format( + dataset_path)) + for ext in file_dict: + ext_num[ext] = len(file_dict[ext]) + + # local dataset + if ext_num: + formatter_num = {} + for name, formatter in FORMATTERS.modules.items(): + formatter_num[name] = 0 + for ext in ext_num: + if ext in formatter.SUFFIXES: + formatter_num[name] += ext_num[ext] + formatter = max(formatter_num, key=lambda x: formatter_num[x]) + target_suffixes = set(ext_num.keys()).intersection( + set(FORMATTERS.modules[formatter].SUFFIXES)) + return FORMATTERS.modules[formatter](dataset_path, + text_keys=text_keys, + suffixes=target_suffixes, + add_suffix=add_suffix, + **kwargs) + + # try huggingface dataset hub + elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: + return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) + + # no data + else: + raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' + f'It might be because Data-Juicer doesn\'t support ' + f'the format of this dataset, or the path of this ' + f'dataset is incorrect.Please check if it\'s a valid ' + f'dataset path and retry.') diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index ff34fac3b..c62a6b845 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -1,14 +1,10 @@ -import os from typing import List, Union import numpy as np from datasets import Dataset, concatenate_datasets from loguru import logger -from data_juicer.format.formatter import (FORMATTERS, BaseFormatter, - RemoteFormatter) -from data_juicer.utils.file_utils import (find_files_with_suffix, - is_absolute_path) +from data_juicer.format.formatter import BaseFormatter from data_juicer.utils.sample import random_sample @@ -59,38 +55,14 @@ def __init__(self, self.sample_numbers = sample_numbers self.weights = weights - self.formatters = [ - load_formatter(dataset_path=data_prefix, - suffixes=suffixes, - text_keys=text_keys, - add_suffix=add_suffix, - **kwargs) for data_prefix in data_prefixes - ] - - def _get_weight(self, data_prefix): - """ - Split every dataset path and its weight. - - :param data_prefix: a dataset file or a dataset dir or a list of - them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` - :return: list of dataset path and list of weights - """ - data_prefix = data_prefix.split() - weights = [] - prefixes = [] - - for i in range(len(data_prefix)): - try: - value = max(float(data_prefix[i]), 0.0) - weights.append(value) - except: # noqa: E722 - value = data_prefix[i].strip() - - # if not set weight, use 1.0 as default - if i == 0 or len(weights) == len(prefixes): - weights.append(1.0) - prefixes.append(value) - return prefixes, weights + self.formatters = None + # [ + # load_formatter(dataset_path=data_prefix, + # suffixes=suffixes, + # text_keys=text_keys, + # add_suffix=add_suffix, + # **kwargs) for data_prefix in data_prefixes + # ] def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: """ @@ -114,61 +86,3 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: mixed_dataset = NestedDataset(concatenate_datasets(dataset_list)) logger.info(f'There are {len(mixed_dataset)} in final dataset') return mixed_dataset - - -def load_formatter(dataset_path, - text_keys=None, - suffixes=None, - add_suffix=False, - **kwargs) -> BaseFormatter: - """ - Load the appropriate formatter for different types of data formats. - - :param dataset_path: Path to dataset file or dataset directory - :param text_keys: key names of field that stores sample text. - Default: None - :param suffixes: the suffix of files that will be read. Default: - None - :return: a dataset formatter. - """ - - if suffixes is None: - suffixes = [] - ext_num = {} - if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): - file_dict = find_files_with_suffix(dataset_path, suffixes) - if not file_dict: - raise IOError( - 'Unable to find files matching the suffix from {}'.format( - dataset_path)) - for ext in file_dict: - ext_num[ext] = len(file_dict[ext]) - - # local dataset - if ext_num: - formatter_num = {} - for name, formatter in FORMATTERS.modules.items(): - formatter_num[name] = 0 - for ext in ext_num: - if ext in formatter.SUFFIXES: - formatter_num[name] += ext_num[ext] - formatter = max(formatter_num, key=lambda x: formatter_num[x]) - target_suffixes = set(ext_num.keys()).intersection( - set(FORMATTERS.modules[formatter].SUFFIXES)) - return FORMATTERS.modules[formatter](dataset_path, - text_keys=text_keys, - suffixes=target_suffixes, - add_suffix=add_suffix, - **kwargs) - - # try huggingface dataset hub - elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: - return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) - - # no data - else: - raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' - f'It might be because Data-Juicer doesn\'t support ' - f'the format of this dataset, or the path of this ' - f'dataset is incorrect.Please check if it\'s a valid ' - f'dataset path and retry.') From 70fffd2c1a85f72f99e517c4e00f604925d08425 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 16 Dec 2024 09:48:35 -0800 Subject: [PATCH 21/56] add DataValidator logic --- data_juicer/core/data/load_strategy.py | 157 ++++++++++++++++++++----- data_juicer/download/downloader.py | 39 ++++++ 2 files changed, 166 insertions(+), 30 deletions(-) diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 20d15194d..c437c2276 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Type, Union from data_juicer.core.data import DJDataset, RayDataset +from data_juicer.download.downloader import validate_snapshot_format from data_juicer.utils.lazy_loader import LazyLoader ray = LazyLoader('ray', 'ray') @@ -39,17 +40,23 @@ def matches(self, other: 'StrategyKey') -> bool: and fnmatch.fnmatch(other.data_source, self.data_source)) -class ValidationError(Exception): +class ConfigValidationError(Exception): """Custom exception for validation errors""" pass +class DataValidationError(Exception): + """Custom exception for data validation errors""" + pass + + class ConfigValidator: """Mixin class for configuration validation""" # Define validation rules for each strategy type - VALIDATION_RULES = { + CONFIG_VALIDATION_RULES = { 'required_fields': [], # Fields that must be present + 'optional_fields': [], # Fields that are optional 'field_types': {}, # Expected types for fields 'custom_validators': {} # Custom validation functions } @@ -66,34 +73,89 @@ def validate_config(self, ds_config: Dict) -> None: """ # Check required fields missing_fields = [ - field for field in self.VALIDATION_RULES['required_fields'] + field for field in self.CONFIG_VALIDATION_RULES['required_fields'] if field not in ds_config ] if missing_fields: - raise ValidationError( + raise ConfigValidationError( f"Missing required fields: {', '.join(missing_fields)}") + # Optional fields + # no need for any special checks + # Check field types - for field, expected_type in self.VALIDATION_RULES['field_types'].items( - ): + for field, expected_type in self.CONFIG_VALIDATION_RULES[ + 'field_types'].items(): if field in ds_config: value = ds_config[field] if not isinstance(value, expected_type): - raise ValidationError(f"Field '{field}' must be of " - "type '{expected_type.__name__}', " - f"got '{type(value).__name__}'") + raise ConfigValidationError( + f"Field '{field}' must be of " + "type '{expected_type.__name__}', " + f"got '{type(value).__name__}'") # Run custom validators - for field, validator in self.VALIDATION_RULES[ + for field, validator in self.CONFIG_VALIDATION_RULES[ 'custom_validators'].items(): if field in ds_config: try: validator(ds_config[field]) except Exception as e: - raise ValidationError( + raise ConfigValidationError( f"Validation failed for field '{field}': {str(e)}") +class DataValidator: + """Mixin class for data content validation""" + + # Define data validation rules + DATA_VALIDATION_RULES = { + 'required_columns': [], # Columns that must be present in the dataset + 'column_types': {}, # Expected types for columns + 'custom_validators': {} # Custom validation functions for data content + } + + def validate_data(self, dataset) -> None: + """ + Validate the actual dataset content. + + Args: + dataset: The loaded dataset to validate + + Raises: + DataValidationError: If validation fails + """ + # Check required columns + if hasattr(dataset, 'column_names'): + missing_columns = [ + col for col in self.DATA_VALIDATION_RULES['required_columns'] + if col not in dataset.column_names + ] + if missing_columns: + raise DataValidationError( + f"Missing required columns: {', '.join(missing_columns)}") + + # Check column types + for col, expected_type in self.DATA_VALIDATION_RULES[ + 'column_types'].items(): + if col in dataset.column_names: + # Sample check for performance + sample = dataset.select(range(min(100, len(dataset)))) + if not all( + isinstance(val, expected_type) for val in sample[col]): + raise DataValidationError( + f"Column '{col}' contains values of incorrect type") + + # Run custom validators + for validator_name, validator in self.DATA_VALIDATION_RULES[ + 'custom_validators'].items(): + try: + validator(dataset) + except Exception as e: + raise DataValidationError( + f"Data validation '{validator_name}' failed: {str(e)}") + + class DataLoadStrategy(ABC, ConfigValidator): """ abstract class for data load strategy @@ -221,19 +283,12 @@ def load_data(self, cfg: Namespace) -> DJDataset: @DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): - VALIDATION_RULES = { - 'required_fields': ['dataset_path'], + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], 'field_types': { - 'dataset_path': (str, list), # Can be string or list - 'cache_dir': str, - 'shuffle': bool + 'path': (str, list) # Can be string or list }, - 'custom_validators': { - 'dataset_path': - lambda x: - (isinstance(x, (str, list)) and - (isinstance(x, str) or all(isinstance(p, str) for p in x))) - } + 'custom_validators': {} } def load_data(self, cfg: Namespace): @@ -243,16 +298,12 @@ def load_data(self, cfg: Namespace): @DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface') class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): - VALIDATION_RULES = { - 'required_fields': ['dataset_name', 'split'], + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], 'field_types': { - 'dataset_name': str, - 'split': str, - 'streaming': bool + 'path': (str, list) # Can be string or list }, - 'custom_validators': { - 'split': lambda x: x in ['train', 'test', 'validation'] - } + 'custom_validators': {} } def load_data(self, cfg: Namespace): @@ -266,6 +317,14 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): rely on AutoFormatter for actual data loading """ + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], + 'field_types': { + 'path': (str, list) # Can be string or list + }, + 'custom_validators': {} + } + def load_data(self, cfg: Namespace): pass @@ -276,6 +335,14 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): data load strategy for Huggingface dataset for LocalExecutor """ + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], + 'field_types': { + 'path': (str, list) # Can be string or list + }, + 'custom_validators': {} + } + def load_data(self, cfg: Namespace): pass @@ -296,6 +363,14 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): data load strategy for arxiv dataset for LocalExecutor """ + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], + 'field_types': { + 'path': (str, list) # Can be string or list + }, + 'custom_validators': {} + } + def load_data(self, cfg: Namespace): pass @@ -306,6 +381,14 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): data load strategy for wiki dataset for LocalExecutor """ + CONFIG_VALIDATION_RULES = { + 'required_fields': ['path'], + 'field_types': { + 'path': (str, list) # Can be string or list + }, + 'custom_validators': {} + } + def load_data(self, cfg: Namespace): pass @@ -316,5 +399,19 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): data load strategy for commoncrawl dataset for LocalExecutor """ + CONFIG_VALIDATION_RULES = { + 'required_fields': ['start_snapshot', 'end_snapshot'], + 'optional_fields': ['aws', 'url_limit'], + 'field_types': { + 'start_snapshot': str, + 'end_snapshot': str + }, + 'custom_validators': { + 'start_snashot': validate_snapshot_format, + 'end_snapshot': validate_snapshot_format, + 'url_limit': lambda x: x > 0 + } + } + def load_data(self, cfg: Namespace): pass diff --git a/data_juicer/download/downloader.py b/data_juicer/download/downloader.py index 8cebfec54..79eb186da 100644 --- a/data_juicer/download/downloader.py +++ b/data_juicer/download/downloader.py @@ -7,6 +7,7 @@ from urllib.parse import urljoin import pandas as pd +import regex as re import requests from bs4 import BeautifulSoup from datasets import Dataset @@ -222,3 +223,41 @@ def get_arxiv_urls(): urls.sort() return urls + + +def validate_snapshot_format(snapshot: Optional[str]) -> None: + """ + Validate snapshot format 'YYYY-WW'. + + Args: + snapshot: Snapshot string in format 'YYYY-WW' or None + + Raises: + ValueError: If format is invalid + """ + if snapshot is None: + return + + # Check basic format with regex + pattern = r'^\d{4}-\d{2}$' + if not re.match(pattern, snapshot): + raise ValueError(f'Invalid snapshot format: {snapshot}. ' + "Expected format: 'YYYY-WW' (e.g., '2020-50')") + + # Parse year and week + try: + year, week = map(int, snapshot.split('-')) + + # Validate year + if not (2000 <= year <= 2100): # Reasonable year range + raise ValueError(f'Year must be between 2000 and 2100, got {year}') + + # Validate week number (1-53) + if not (1 <= week <= 53): + raise ValueError(f'Week must be between 1 and 53, got {week}') + + except ValueError as e: + if str(e).startswith('Week') or str(e).startswith('Year'): + raise + raise ValueError(f'Invalid snapshot format: {snapshot}. ' + "Expected format: 'YYYY-WW' (e.g., '2020-50')") From bbc303d2764d2a148c516e1a6117cd078adba720 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 16 Dec 2024 12:58:37 -0800 Subject: [PATCH 22/56] data validator as separate pre-processing --- configs/datasets/validation.yaml | 17 +++ data_juicer/core/data/data_validator.py | 132 +++++++++++++++++++++++ data_juicer/core/data/dataset_builder.py | 21 +++- data_juicer/core/data/load_strategy.py | 56 ---------- 4 files changed, 167 insertions(+), 59 deletions(-) create mode 100644 configs/datasets/validation.yaml create mode 100644 data_juicer/core/data/data_validator.py diff --git a/configs/datasets/validation.yaml b/configs/datasets/validation.yaml new file mode 100644 index 000000000..dbe1cbedd --- /dev/null +++ b/configs/datasets/validation.yaml @@ -0,0 +1,17 @@ +dataset: + type: ondisk + path: path/to/data.json + +validators: + - type: conversation + min_turns: 2 + max_turns: 20 + - type: required_fields + required_fields: + - "text" + - "metadata" + - "language" + field_types: + text: "str" + metadata: "dict" + language: "str" diff --git a/data_juicer/core/data/data_validator.py b/data_juicer/core/data/data_validator.py new file mode 100644 index 000000000..b4fe9d28a --- /dev/null +++ b/data_juicer/core/data/data_validator.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, Type, Union + +from data_juicer.core.data.dj_dataset import NestedDataset +from data_juicer.core.data.ray_dataset import RayDataset + + +class DataValidator(ABC): + """Base class for data validation""" + + @abstractmethod + def validate(self, dataset) -> None: + """ + Validate dataset content + + Args: + dataset: The dataset to validate + + Raises: + DataValidationError: If validation fails + """ + pass + + +class DataValidationError(Exception): + """Custom exception for data validation errors""" + pass + + +class DataValidatorRegistry: + """Registry for data validators""" + + _validators: Dict[str, Type[DataValidator]] = {} + + @classmethod + def register(cls, validator_type: str): + + def decorator(validator_class: Type[DataValidator]): + cls._validators[validator_type] = validator_class + return validator_class + + return decorator + + @classmethod + def get_validator(cls, + validator_type: str) -> Optional[Type[DataValidator]]: + return cls._validators.get(validator_type) + + +@DataValidatorRegistry.register('conversation') +class ConversationDataValidator(DataValidator): + """Validator for conversation data""" + + def __init__(self, config: Dict): + self.config = config + # Validation rules specific to conversation data + self.required_columns = ['text'] + self.min_turns = config.get('min_turns', 2) + self.max_turns = config.get('max_turns', 100) + + def validate(self, dataset) -> None: + # Check required columns + if not all(col in dataset.column_names + for col in self.required_columns): + raise DataValidationError( + f'Missing required columns: {self.required_columns}') + + # Validate conversation structure + for item in dataset: + turns = self._parse_turns(item['text']) + if not (self.min_turns <= len(turns) <= self.max_turns): + raise DataValidationError( + f'Conversation must have between {self.min_turns} and ' + f'{self.max_turns} turns') + + # Additional conversation-specific validations... + + +@DataValidatorRegistry.register('code') +class CodeDataValidator(DataValidator): + """Validator for code data""" + + def __init__(self, config: Dict): + self.config = config + self.required_columns = ['code', 'language'] + self.supported_languages = config.get('supported_languages', []) + + def validate(self, dataset) -> None: + # Implement code-specific validation logic... + pass + + +@DataValidatorRegistry.register('required_fields') +class RequiredFieldsValidator(DataValidator): + """Validator that checks for required fields in dataset""" + + def __init__(self, config: Dict): + """ + Initialize validator with config + + Args: + config: Dict containing: + - required_fields: List of field names that must exist + - field_types: Optional map of field names to expected types + """ + self.config = config + self.required_fields = config['required_fields'] + self.field_types = config.get('field_types', {}) + + def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: + """ + Validate dataset has required fields with correct types + + Args: + dataset: NestedDataset or RayDataset to validate + + Raises: + DataValidationError: If validation fails + """ + # Check if fields exist in dataset + if isinstance(dataset, NestedDataset): + available_fields = set(dataset.column_names) + elif isinstance(dataset, RayDataset): + available_fields = set(dataset.schema().names) + else: + raise DataValidationError( + f'Unsupported dataset type: {type(dataset)}') + + missing_fields = set(self.required_fields) - available_fields + if missing_fields: + raise DataValidationError( + f'Dataset missing required fields: {missing_fields}') diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 06e4dbbf0..11da30c2e 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -3,6 +3,7 @@ from typing import List, Tuple, Union from data_juicer.core.data import NestedDataset +from data_juicer.core.data.data_validator import DataValidatorRegistry from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset from data_juicer.utils.file_utils import is_absolute_path @@ -34,12 +35,26 @@ def __init__(self, cfg): DataLoadStrategyRegistry.get_strategy_class( executor_type, data_type, data_source)(ds_config)) + self.validators = [] + if hasattr(cfg, 'validators'): + for validator_config in cfg.validators: + validator_type = validator_config['type'] + validator_cls = DataValidatorRegistry.get_validator( + validator_type) + if validator_cls: + self.validators.append(validator_cls(validator_config)) + def load_dataset(self) -> Union[NestedDataset, RayDataset]: - # handle mixture dataset, nested dataset - # handle sampling of mixture datasets + # load dataset with its load strategy + # do data validation _datasets = [] for f in self.load_strategies: - _datasets.append(f.load_data(self.cfg)) + _dataset = f.load_data(self.cfg) + for validator in self.validators: + validator.validate(_dataset) + _datasets.append(_dataset) + + # handle data mixture return _datasets[0] diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index c437c2276..9bee4903e 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -45,11 +45,6 @@ class ConfigValidationError(Exception): pass -class DataValidationError(Exception): - """Custom exception for data validation errors""" - pass - - class ConfigValidator: """Mixin class for configuration validation""" @@ -105,57 +100,6 @@ def validate_config(self, ds_config: Dict) -> None: f"Validation failed for field '{field}': {str(e)}") -class DataValidator: - """Mixin class for data content validation""" - - # Define data validation rules - DATA_VALIDATION_RULES = { - 'required_columns': [], # Columns that must be present in the dataset - 'column_types': {}, # Expected types for columns - 'custom_validators': {} # Custom validation functions for data content - } - - def validate_data(self, dataset) -> None: - """ - Validate the actual dataset content. - - Args: - dataset: The loaded dataset to validate - - Raises: - DataValidationError: If validation fails - """ - # Check required columns - if hasattr(dataset, 'column_names'): - missing_columns = [ - col for col in self.DATA_VALIDATION_RULES['required_columns'] - if col not in dataset.column_names - ] - if missing_columns: - raise DataValidationError( - f"Missing required columns: {', '.join(missing_columns)}") - - # Check column types - for col, expected_type in self.DATA_VALIDATION_RULES[ - 'column_types'].items(): - if col in dataset.column_names: - # Sample check for performance - sample = dataset.select(range(min(100, len(dataset)))) - if not all( - isinstance(val, expected_type) for val in sample[col]): - raise DataValidationError( - f"Column '{col}' contains values of incorrect type") - - # Run custom validators - for validator_name, validator in self.DATA_VALIDATION_RULES[ - 'custom_validators'].items(): - try: - validator(dataset) - except Exception as e: - raise DataValidationError( - f"Data validation '{validator_name}' failed: {str(e)}") - - class DataLoadStrategy(ABC, ConfigValidator): """ abstract class for data load strategy From 4b6065f58fabf95cb9a05c1bcb872d5db2e75042 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 25 Dec 2024 10:18:48 -0800 Subject: [PATCH 23/56] update data validator logic and add/fix test cases --- data_juicer/core/analyzer.py | 17 ++- data_juicer/core/data/config_validator.py | 61 +++++++++++ data_juicer/core/data/data_validator.py | 54 +++++++++- data_juicer/core/data/dataset_builder.py | 23 +++- data_juicer/core/data/load_strategy.py | 61 +---------- data_juicer/core/data/ray_dataset.py | 3 + data_juicer/format/__init__.py | 7 +- data_juicer/format/load.py | 4 +- tests/core/test_data_validator.py | 124 ++++++++++++++++++++++ tests/download/test_download.py | 70 ++++++++++-- 10 files changed, 335 insertions(+), 89 deletions(-) create mode 100644 data_juicer/core/data/config_validator.py create mode 100644 tests/core/test_data_validator.py diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index 2ae4d3511..42d6a01e2 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -7,7 +7,7 @@ from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis from data_juicer.config import init_configs -from data_juicer.format import load_formatter +from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.ops import Filter, load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils @@ -42,14 +42,9 @@ def __init__(self, cfg: Optional[Namespace] = None): f'[{self.cfg.cache_compress}]') cache_utils.CACHE_COMPRESS = self.cfg.cache_compress - # setup formatter + # setup dataset builder logger.info('Setting up data formatter...') - self.formatter = load_formatter( - dataset_path=self.cfg.dataset_path, - generated_dataset_config=self.cfg.generated_dataset_config, - text_keys=self.cfg.text_keys, - suffixes=self.cfg.suffixes, - add_suffix=self.cfg.add_suffix) + self.dataset_builder = DatasetBuilder(self.cfg) # prepare exporter and check export path suffix # NOTICE: no need to export dataset texts for analyzer @@ -84,9 +79,9 @@ def run(self, """ # 1. format data logger.info('Loading dataset from data formatter...') - if load_data_np is None: - load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) + if load_data_np is not None: + self.dataset_builder.set_dataset_path(load_data_np) + dataset = self.dataset_builder.load_dataset() # extract processes logger.info('Preparing process operators...') diff --git a/data_juicer/core/data/config_validator.py b/data_juicer/core/data/config_validator.py new file mode 100644 index 000000000..a48832912 --- /dev/null +++ b/data_juicer/core/data/config_validator.py @@ -0,0 +1,61 @@ +from typing import Dict + + +class ConfigValidationError(Exception): + """Custom exception for validation errors""" + pass + + +class ConfigValidator: + """Mixin class for configuration validation""" + + # Define validation rules for each strategy type + CONFIG_VALIDATION_RULES = { + 'required_fields': [], # Fields that must be present + 'optional_fields': [], # Fields that are optional + 'field_types': {}, # Expected types for fields + 'custom_validators': {} # Custom validation functions + } + + def validate_config(self, ds_config: Dict) -> None: + """ + Validate the configuration dictionary. + + Args: + ds_config: Configuration dictionary to validate + + Raises: + ValidationError: If validation fails + """ + # Check required fields + missing_fields = [ + field for field in self.CONFIG_VALIDATION_RULES['required_fields'] + if field not in ds_config + ] + if missing_fields: + raise ConfigValidationError( + f"Missing required fields: {', '.join(missing_fields)}") + + # Optional fields + # no need for any special checks + + # Check field types + for field, expected_type in self.CONFIG_VALIDATION_RULES[ + 'field_types'].items(): + if field in ds_config: + value = ds_config[field] + if not isinstance(value, expected_type): + raise ConfigValidationError( + f"Field '{field}' must be of " + "type '{expected_type.__name__}', " + f"got '{type(value).__name__}'") + + # Run custom validators + for field, validator in self.CONFIG_VALIDATION_RULES[ + 'custom_validators'].items(): + if field in ds_config: + try: + validator(ds_config[field]) + except Exception as e: + raise ConfigValidationError( + f"Validation failed for field '{field}': {str(e)}") diff --git a/data_juicer/core/data/data_validator.py b/data_juicer/core/data/data_validator.py index b4fe9d28a..19a1f4d64 100644 --- a/data_juicer/core/data/data_validator.py +++ b/data_juicer/core/data/data_validator.py @@ -8,6 +8,9 @@ class DataValidator(ABC): """Base class for data validation""" + def __init__(self, config: Dict): + self.config = config + @abstractmethod def validate(self, dataset) -> None: """ @@ -52,7 +55,8 @@ class ConversationDataValidator(DataValidator): """Validator for conversation data""" def __init__(self, config: Dict): - self.config = config + super().__init__(config) + # Validation rules specific to conversation data self.required_columns = ['text'] self.min_turns = config.get('min_turns', 2) @@ -81,7 +85,8 @@ class CodeDataValidator(DataValidator): """Validator for code data""" def __init__(self, config: Dict): - self.config = config + super().__init__(config) + self.required_columns = ['code', 'language'] self.supported_languages = config.get('supported_languages', []) @@ -102,10 +107,14 @@ def __init__(self, config: Dict): config: Dict containing: - required_fields: List of field names that must exist - field_types: Optional map of field names to expected types + - allow_missing: Optional float for max ratio missing allowed """ - self.config = config + super().__init__(config) + self.required_fields = config['required_fields'] self.field_types = config.get('field_types', {}) + # Default no missing allowed + self.allow_missing = config.get('allow_missing', 0.0) def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: """ @@ -130,3 +139,42 @@ def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: if missing_fields: raise DataValidationError( f'Dataset missing required fields: {missing_fields}') + + # Check field types and missing values + for field in self.required_fields: + # Get expected type if specified + expected_type = self.field_types.get(field) + + # Sample data for validation + # For large datasets, we check a sample for performance + MAX_SAMPLE_SIZE = 1000 + if isinstance(dataset, NestedDataset): + sample_size = min(MAX_SAMPLE_SIZE, len(dataset)) + sample = dataset.select(range(sample_size)) + values = sample[field] + elif isinstance(dataset, RayDataset): # RayDataset + sample_size = min(MAX_SAMPLE_SIZE, dataset.data.count()) + sample = dataset.data.take(sample_size) + values = [row[field] for row in sample] + else: + raise NotImplementedError + + # Check for missing values + missing_count = sum(1 for v in values if v is None) + missing_ratio = missing_count / len(values) + if missing_ratio > self.allow_missing: + raise DataValidationError( + f"Field '{field}' has {missing_ratio:.1%} missing values, " + f'exceeding allowed {self.allow_missing:.1%}') + + # Check types if specified + if expected_type: + invalid_types = [ + type(v) for v in values + if v is not None and not isinstance(v, expected_type) + ] + if invalid_types: + raise DataValidationError( + f"Field '{field}' contains values of incorrect type. " + f'Expected {expected_type.__name__}, ' + f'got {set(t.__name__ for t in invalid_types)}') diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 11da30c2e..77e80cd33 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -13,6 +13,7 @@ class DatasetBuilder(object): def __init__(self, cfg): self.cfg = cfg + # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) @@ -22,6 +23,7 @@ def __init__(self, cfg): raise ValueError( 'Unable to initialize dataset; should have one of ' 'dataset_path or dataset in configurations') + # dataset config could be a list or a single entry; retrofit if not isinstance(ds_configs, list): ds_configs = [ds_configs] @@ -35,6 +37,7 @@ def __init__(self, cfg): DataLoadStrategyRegistry.get_strategy_class( executor_type, data_type, data_source)(ds_config)) + # initialize data validators self.validators = [] if hasattr(cfg, 'validators'): for validator_config in cfg.validators: @@ -45,11 +48,12 @@ def __init__(self, cfg): self.validators.append(validator_cls(validator_config)) def load_dataset(self) -> Union[NestedDataset, RayDataset]: - # load dataset with its load strategy - # do data validation _datasets = [] for f in self.load_strategies: + # load dataset with its load strategy _dataset = f.load_data(self.cfg) + + # do data validation for validator in self.validators: validator.validate(_dataset) _datasets.append(_dataset) @@ -57,6 +61,21 @@ def load_dataset(self) -> Union[NestedDataset, RayDataset]: # handle data mixture return _datasets[0] + @classmethod + def load_dataset_by_generated_config(cls, generated_dataset_config): + """ + load dataset by generated config + """ + assert isinstance(generated_dataset_config, + dict) and 'type' in generated_dataset_config + args = generated_dataset_config.copy() + + # TODO finish the auto local dataset part + obj_name = args.pop('type') + from data_juicer.format.formatter import FORMATTERS + dataset = FORMATTERS.modules[obj_name](**args).load_dataset() + return dataset + def rewrite_cli_datapath(dataset_path) -> List: """ diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 9bee4903e..f94e51d86 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Type, Union from data_juicer.core.data import DJDataset, RayDataset +from data_juicer.core.data.config_validator import ConfigValidator from data_juicer.download.downloader import validate_snapshot_format from data_juicer.utils.lazy_loader import LazyLoader @@ -40,66 +41,6 @@ def matches(self, other: 'StrategyKey') -> bool: and fnmatch.fnmatch(other.data_source, self.data_source)) -class ConfigValidationError(Exception): - """Custom exception for validation errors""" - pass - - -class ConfigValidator: - """Mixin class for configuration validation""" - - # Define validation rules for each strategy type - CONFIG_VALIDATION_RULES = { - 'required_fields': [], # Fields that must be present - 'optional_fields': [], # Fields that are optional - 'field_types': {}, # Expected types for fields - 'custom_validators': {} # Custom validation functions - } - - def validate_config(self, ds_config: Dict) -> None: - """ - Validate the configuration dictionary. - - Args: - ds_config: Configuration dictionary to validate - - Raises: - ValidationError: If validation fails - """ - # Check required fields - missing_fields = [ - field for field in self.CONFIG_VALIDATION_RULES['required_fields'] - if field not in ds_config - ] - if missing_fields: - raise ConfigValidationError( - f"Missing required fields: {', '.join(missing_fields)}") - - # Optional fields - # no need for any special checks - - # Check field types - for field, expected_type in self.CONFIG_VALIDATION_RULES[ - 'field_types'].items(): - if field in ds_config: - value = ds_config[field] - if not isinstance(value, expected_type): - raise ConfigValidationError( - f"Field '{field}' must be of " - "type '{expected_type.__name__}', " - f"got '{type(value).__name__}'") - - # Run custom validators - for field, validator in self.CONFIG_VALIDATION_RULES[ - 'custom_validators'].items(): - if field in ds_config: - try: - validator(ds_config[field]) - except Exception as e: - raise ConfigValidationError( - f"Validation failed for field '{field}': {str(e)}") - - class DataLoadStrategy(ABC, ConfigValidator): """ abstract class for data load strategy diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 352cd95dc..4893762b4 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -24,6 +24,9 @@ def __init__(self, if cfg: self.num_proc = cfg.np + def schema(self): + return self.data.schema() + def process(self, operators, *, diff --git a/data_juicer/format/__init__.py b/data_juicer/format/__init__.py index a0473fca3..368dcf220 100644 --- a/data_juicer/format/__init__.py +++ b/data_juicer/format/__init__.py @@ -5,14 +5,13 @@ from .empty_formatter import EmptyFormatter, RayEmptyFormatter from .formatter import LocalFormatter, RemoteFormatter from .json_formatter import JsonFormatter -from .load import load_formatter from .mixture_formatter import MixtureFormatter from .parquet_formatter import ParquetFormatter from .text_formatter import TextFormatter from .tsv_formatter import TsvFormatter __all__ = [ - 'load_formatter', 'JsonFormatter', 'LocalFormatter', 'RemoteFormatter', - 'TextFormatter', 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', - 'MixtureFormatter', 'EmptyFormatter', 'RayEmptyFormatter' + 'JsonFormatter', 'LocalFormatter', 'RemoteFormatter', 'TextFormatter', + 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', 'MixtureFormatter', + 'EmptyFormatter', 'RayEmptyFormatter' ] diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py index ef58b124b..4cfbc8939 100644 --- a/data_juicer/format/load.py +++ b/data_juicer/format/load.py @@ -1,7 +1,7 @@ import os -from data_juicer.format.formatter import (FORMATTERS, BaseFormatter, - MixtureFormatter, RemoteFormatter) +from data_juicer.format import MixtureFormatter, RemoteFormatter +from data_juicer.format.formatter import FORMATTERS, BaseFormatter from data_juicer.utils.file_utils import (find_files_with_suffix, is_absolute_path) diff --git a/tests/core/test_data_validator.py b/tests/core/test_data_validator.py new file mode 100644 index 000000000..7146da752 --- /dev/null +++ b/tests/core/test_data_validator.py @@ -0,0 +1,124 @@ +from unittest import TestCase, main +import datasets +import ray +import ray.data +import pandas as pd + +from data_juicer.core.data import NestedDataset, RayDataset +from data_juicer.core.data.data_validator import (DataValidationError, + RequiredFieldsValidator +) + + +# Test RequiredFieldsValidator +class TestRequiredFieldsValidator(TestCase): + + def setUp(self): + # Create sample DataFrame + self.df = pd.DataFrame({ + 'text': ['Hello', 'World', None, 'Test'], + 'metadata': [{'lang': 'en'}, {'lang': 'es'}, {'lang': 'fr'}, None], + 'score': [1.0, 2.0, 3.0, 4.0] + }) + + # Create dataset + self.dataset = NestedDataset(datasets.Dataset.from_pandas(self.df)) + + # Create ray dataset + self.ray_dataset = RayDataset(ray.data.from_pandas(self.df)) + + + def test_basic_validation(self): + """Test basic field validation""" + config = { + 'required_fields': ['text', 'metadata'], + 'allow_missing': .25 + } + validator = RequiredFieldsValidator(config) + + # Should pass + validator.validate(self.dataset) + + # Should fail with missing field + config['required_fields'].append('nonexistent') + validator = RequiredFieldsValidator(config) + with self.assertRaises(DataValidationError) as exc: + validator.validate(self.dataset) + self.assertIn("missing required fields", str(exc.exception).lower()) + + def test_type_validation(self): + """Test field type validation""" + # Should pass + config = { + 'required_fields': ['text', 'score'], + 'field_types': { + 'text': str, + 'score': float + }, + 'allow_missing': .25 + } + validator = RequiredFieldsValidator(config) + validator.validate(self.dataset) + + # Should fail with wrong type + config['field_types']['score'] = str + validator = RequiredFieldsValidator(config) + with self.assertRaises(DataValidationError) as exc: + validator.validate(self.dataset) + self.assertIn("incorrect type", str(exc.exception).lower()) + + def test_ray_dataset_support(self): + """Test validation with RayDataset""" + config = { + 'required_fields': ['text', 'metadata'], + 'field_types': { + 'text': str, + 'metadata': dict + }, + 'allow_missing': .25 + } + validator = RequiredFieldsValidator(config) + + # Should pass + validator.validate(self.ray_dataset) + + def test_invalid_dataset_type(self): + """Test validation with unsupported dataset type""" + config = { + 'required_fields': ['text'] + } + validator = RequiredFieldsValidator(config) + + with self.assertRaises(DataValidationError) as exc: + validator.validate([1, 2, 3]) # Invalid dataset type + self.assertIn("unsupported dataset type", str(exc.exception).lower()) + + def test_empty_required_fields(self): + """Test validation with empty required fields""" + config = { + 'required_fields': [] + } + validator = RequiredFieldsValidator(config) + + # Should pass as no fields are required + validator.validate(self.dataset) + + def test_multiple_dataset_types(self): + """Test validation works with different dataset types""" + datasets_to_test = [ + ('nested', self.dataset), + ('ray', self.ray_dataset) + ] + + config = { + 'required_fields': ['text', 'metadata', 'score'], + 'allow_missing': .25 + } + validator = RequiredFieldsValidator(config) + + for name, dataset in datasets_to_test: + with self.subTest(dataset_type=name): + validator.validate(dataset) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tests/download/test_download.py b/tests/download/test_download.py index d07e46e5b..a86892cbd 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch, MagicMock import tempfile import os import shutil @@ -18,19 +19,74 @@ def tearDown(self): def test_wikipedia_urls(self): dump_date = "20241101" - urls = get_wikipedia_urls(dump_date=dump_date) - assert len(urls) > 3 - assert urls[0] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2" - assert urls[1] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream2.xml-p41243p151573.bz2" - assert urls[2] == "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream3.xml-p151574p311329.bz2" + expected_urls = [ + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2", + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream2.xml-p41243p151573.bz2", + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream3.xml-p311574p311329.bz2" + ] + + with patch('requests.get') as mock_get: + # Mock the response from Wikipedia API + mock_response = MagicMock() + mock_response.text = "some HTML containing the dump files" + mock_get.return_value = mock_response + + urls = get_wikipedia_urls(dump_date=dump_date) + + # Verify the function made the correct API call + mock_get.assert_called_once_with( + f"https://dumps.wikimedia.org/enwiki/{dump_date}/") + + # Verify returned URLs + assert len(urls) > 3 + assert urls[0] == expected_urls[0] + assert urls[1] == expected_urls[1] + assert urls[2] == expected_urls[2] - def test_wikipedia_download(self): + @patch('data_juicer.download.wikipedia.get_wikipedia_urls') + @patch('data_juicer.download.wikipedia.download_file') + @patch('data_juicer.download.wikipedia.process_wiki_dump') + def test_wikipedia_download(self, mock_process, mock_download, mock_get_urls): dump_date = "20241101" url_limit = 1 item_limit = 50 - wiki_df = download_wikipedia(self.temp_dir, dump_date=dump_date, url_limit=url_limit, item_limit=item_limit) + + # Mock the URLs returned + mock_urls = [ + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2" + ] + mock_get_urls.return_value = mock_urls + + # Mock the download process + mock_download.return_value = "/tmp/mock_downloaded_file.bz2" + + # Mock the processing result + mock_df = MagicMock() + mock_df.take.return_value = [{"text": f"Article {i}"} for i in range(10)] + mock_process.return_value = mock_df + + # Run the function + wiki_df = download_wikipedia( + self.temp_dir, + dump_date=dump_date, + url_limit=url_limit, + item_limit=item_limit + ) + + # Verify the calls + mock_get_urls.assert_called_once_with(dump_date=dump_date) + mock_download.assert_called_once_with( + mock_urls[0], + os.path.join(self.temp_dir, os.path.basename(mock_urls[0])) + ) + mock_process.assert_called_once() + + # Verify the result sample = wiki_df.take(10) assert len(sample) == 10 + + # Verify the mocks were used correctly + mock_df.take.assert_called_once_with(10) if __name__ == '__main__': From 0b153ab8e195c557c8877015325fc88374eba18a Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 2 Jan 2025 10:51:15 -0800 Subject: [PATCH 24/56] [nit] rename test --- tests/core/test_dataload_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index b3f489cfb..f8afc466f 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -7,7 +7,7 @@ class MockStrategy(DataLoadStrategy): def load_data(self): pass -class TestDataLoadStrategyRegistry(unittest.TestCase): +class DataLoadStrategyRegistryTest(unittest.TestCase): def setUp(self): # Clear existing strategies before each test DataLoadStrategyRegistry._strategies = {} From 171b3619a4a7f6b63f411e8e7d41ec877f576efb Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 2 Jan 2025 10:53:02 -0800 Subject: [PATCH 25/56] [nit] rename test again --- tests/core/test_data_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_data_validator.py b/tests/core/test_data_validator.py index 7146da752..dccbd8e96 100644 --- a/tests/core/test_data_validator.py +++ b/tests/core/test_data_validator.py @@ -11,7 +11,7 @@ # Test RequiredFieldsValidator -class TestRequiredFieldsValidator(TestCase): +class RequiredFieldsValidatorTest(TestCase): def setUp(self): # Create sample DataFrame From 6841d19bc22b513aae96bd2930209a415cddc667 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 2 Jan 2025 13:44:24 -0800 Subject: [PATCH 26/56] add builder test cases; update ds config validation logic --- data_juicer/core/data/dataset_builder.py | 34 +++++- data_juicer/core/data/load_strategy.py | 5 + data_juicer/core/executor/base.py | 1 + data_juicer/core/executor/local_executor.py | 2 +- data_juicer/core/executor/ray_executor.py | 2 +- tests/core/test_dataset_builder.py | 120 +++++++++++++++++++- 6 files changed, 154 insertions(+), 10 deletions(-) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 77e80cd33..2e2ec1e84 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -3,6 +3,7 @@ from typing import List, Tuple, Union from data_juicer.core.data import NestedDataset +from data_juicer.core.data.config_validator import ConfigValidationError from data_juicer.core.data.data_validator import DataValidatorRegistry from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset @@ -11,31 +12,50 @@ class DatasetBuilder(object): - def __init__(self, cfg): + def __init__(self, cfg, executor_type): self.cfg = cfg + self.executor_type = executor_type # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) - elif cfg.dataset is not None: + elif cfg.dataset not in (None, []): ds_configs = cfg.dataset else: - raise ValueError( + raise ConfigValidationError( 'Unable to initialize dataset; should have one of ' 'dataset_path or dataset in configurations') # dataset config could be a list or a single entry; retrofit if not isinstance(ds_configs, list): ds_configs = [ds_configs] + + # validate dataset config for type constraints + # 1. ds_config should be a dictionary + # 2. ds_configs should only have one type + # 3. if type is REMOTE, there should only one ds_config + # TODO other constraints; ray dataset only supports ondisk, etc. + for ds_config in ds_configs: + if type(ds_config) != dict: + raise ConfigValidationError( + 'Dataset config should be a dictionary') + types = [ds_config.get('type', None) for ds_config in ds_configs] + if len(set(types)) > 1: + raise ConfigValidationError( + 'Mixture of diff types (ONDISK/REMOTE/...) are not supported') + if types[0] == 'remote' and len(ds_configs) > 1: + raise ConfigValidationError( + 'Multiple remote datasets are not supported') + + # initialize the data load strategies self.load_strategies = [] for ds_config in ds_configs: # initialize data loading strategy - executor_type = ds_config.get('executor_type', None) data_type = ds_config.get('type', None) data_source = ds_config.get('source', None) self.load_strategies.append( DataLoadStrategyRegistry.get_strategy_class( - executor_type, data_type, data_source)(ds_config)) + self.executor_type, data_type, data_source)(ds_config)) # initialize data validators self.validators = [] @@ -49,6 +69,7 @@ def __init__(self, cfg): def load_dataset(self) -> Union[NestedDataset, RayDataset]: _datasets = [] + for f in self.load_strategies: # load dataset with its load strategy _dataset = f.load_data(self.cfg) @@ -58,7 +79,8 @@ def load_dataset(self) -> Union[NestedDataset, RayDataset]: validator.validate(_dataset) _datasets.append(_dataset) - # handle data mixture + # handle data mixture; only supports ONDISK + return _datasets[0] @classmethod diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index f94e51d86..0b0377739 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -72,6 +72,11 @@ def get_strategy_class( 1. Exact match 2. Wildcard matches from most specific to most general """ + # default to wildcard if not provided + executor_type = executor_type or '*' + data_type = data_type or '*' + data_source = data_source or '*' + # Create the lookup key lookup_key = StrategyKey(executor_type, data_type, data_source) diff --git a/data_juicer/core/executor/base.py b/data_juicer/core/executor/base.py index 9ed7e602d..e1d6e2073 100644 --- a/data_juicer/core/executor/base.py +++ b/data_juicer/core/executor/base.py @@ -12,6 +12,7 @@ class ExecutorBase(ABC): @abstractmethod def __init__(self, cfg: Optional[Namespace] = None): self.cfg = init_configs() if cfg is None else cfg + self.executor_type = 'base' @abstractmethod def run(self, diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index 39c9f595e..b9233061e 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -37,7 +37,7 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace. """ super().__init__(cfg) - + self.executor_type = 'local' self.work_dir = self.cfg.work_dir self.tracer = None diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index c6fcef5e8..1a311286c 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -34,7 +34,7 @@ def __init__(self, cfg=None): :param cfg: optional config dict. """ super().__init__(cfg) - + self.executor_type = 'ray' self.work_dir = self.cfg.work_dir self.adapter = Adapter(self.cfg) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 86e0aab0e..c0d6e544d 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -5,17 +5,28 @@ from contextlib import redirect_stdout from io import StringIO from data_juicer.config import init_configs -from data_juicer.core.data.dataset_builder import rewrite_cli_datapath, parse_cli_datapath -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.core.data.dataset_builder import (rewrite_cli_datapath, + parse_cli_datapath, + DatasetBuilder) +from data_juicer.core.data.config_validator import ConfigValidationError +from data_juicer.utils.unittest_utils import (DataJuicerTestCaseBase, + SKIPPED_TESTS) + @SKIPPED_TESTS.register_module() class DatasetBuilderTest(DataJuicerTestCaseBase): def setUp(self): + """Setup basic configuration for tests""" + self.base_cfg = Namespace() + self.base_cfg.dataset_path = None + self.executor_type = 'local' + # Get the directory where this test file is located test_file_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(test_file_dir) + def test_rewrite_cli_datapath_local_single_file(self): dataset_path = "./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) @@ -214,6 +225,111 @@ def test_parse_cli_datapath_multiple_datasets(self): self.assertEqual(weights, expected_weights, f"Failed weights for input: {input_path}") + def test_builder_single_dataset_config(self): + """Test handling of single dataset configuration""" + # Setup single dataset config + self.base_cfg.dataset = { + 'type': 'ondisk', + 'path': 'test.jsonl' + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + + # Verify config was converted to list + self.assertIsInstance(builder.load_strategies, list) + self.assertEqual(len(builder.load_strategies), 1) + + # Verify config content preserved + strategy = builder.load_strategies[0] + self.assertEqual(strategy.ds_config['type'], 'ondisk') + self.assertEqual(strategy.ds_config['path'], 'test.jsonl') + + def test_builder_multiple_dataset_config(self): + """Test handling of multiple dataset configurations""" + # Setup multiple dataset config + self.base_cfg.dataset = [ + { + 'type': 'ondisk', + 'path': 'test1.jsonl' + }, + { + 'type': 'ondisk', + 'path': 'test2.jsonl' + } + ] + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + + # Verify list handling + self.assertEqual(len(builder.load_strategies), 2) + + # Verify each config + self.assertEqual(builder.load_strategies[0].ds_config['path'], 'test1.jsonl') + self.assertEqual(builder.load_strategies[1].ds_config['path'], 'test2.jsonl') + + def test_builder_none_dataset_config(self): + """Test handling when both dataset and dataset_path are None""" + self.base_cfg.dataset = None + + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertIn('dataset_path or dataset', str(context.exception)) + + def test_builder_mixed_dataset_types(self): + """Test validation of mixed dataset types""" + self.base_cfg.dataset = [ + { + 'type': 'ondisk', + 'path': 'test1.jsonl' + }, + { + 'type': 'remote', + 'source': 'some_source' + } + ] + + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertIn('Mixture of diff types', str(context.exception)) + + def test_builder_multiple_remote_datasets(self): + """Test validation of multiple remote datasets""" + self.base_cfg.dataset = [ + { + 'type': 'remote', + 'source': 'source1' + }, + { + 'type': 'remote', + 'source': 'source2' + } + ] + + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertIn('Multiple remote datasets', str(context.exception)) + + def test_builder_empty_dataset_config(self): + """Test handling of empty dataset configuration""" + self.base_cfg.dataset = [] + + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertIn('dataset_path or dataset', str(context.exception)) + + def test_builder_invalid_dataset_config_type(self): + """Test handling of invalid dataset configuration type""" + self.base_cfg.dataset = "invalid_string_config" + + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertIn('Dataset config should be a dictionary', + str(context.exception)) if __name__ == '__main__': unittest.main() From 3128d05100f1ea134373505fe4b6ae511741250c Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 2 Jan 2025 13:46:08 -0800 Subject: [PATCH 27/56] [minor] update test case naming --- tests/core/test_dataset_builder.py | 46 +++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index c0d6e544d..a7d56734c 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -60,29 +60,6 @@ def test_rewrite_cli_datapath_with_weights(self): {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans) - def test_dataset_builder_ondisk_config(self): - test_config_file = './data/test_config.yaml' - out = StringIO() - with redirect_stdout(out): - cfg = init_configs(args=f'--config {test_config_file}'.split()) - self.assertIsInstance(cfg, Namespace) - self.assertEqual(cfg.project_name, 'dataset-ondisk-json') - self.assertEqual(cfg.dataset, - {'path': ['sample.json'], 'type': 'ondisk'}) - self.assertEqual(not cfg.dataset_path, True) - - def test_dataset_builder_ondisk_config_list(self): - test_config_file = './data/test_config_list.yaml' - out = StringIO() - with redirect_stdout(out): - cfg = init_configs(args=f'--config {test_config_file}'.split()) - self.assertIsInstance(cfg, Namespace) - self.assertEqual(cfg.project_name, 'dataset-ondisk-list') - self.assertEqual(cfg.dataset,[ - {'path': ['sample.json'], 'type': 'ondisk'}, - {'path': ['sample.txt'], 'type': 'ondisk'}]) - self.assertEqual(not cfg.dataset_path, True) - @patch('os.path.isdir') @patch('os.path.isfile') def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir): @@ -331,5 +308,28 @@ def test_builder_invalid_dataset_config_type(self): self.assertIn('Dataset config should be a dictionary', str(context.exception)) + def test_builder_ondisk_config(self): + test_config_file = './data/test_config.yaml' + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_config_file}'.split()) + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'dataset-ondisk-json') + self.assertEqual(cfg.dataset, + {'path': ['sample.json'], 'type': 'ondisk'}) + self.assertEqual(not cfg.dataset_path, True) + + def test_builder_ondisk_config_list(self): + test_config_file = './data/test_config_list.yaml' + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_config_file}'.split()) + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'dataset-ondisk-list') + self.assertEqual(cfg.dataset,[ + {'path': ['sample.json'], 'type': 'ondisk'}, + {'path': ['sample.txt'], 'type': 'ondisk'}]) + self.assertEqual(not cfg.dataset_path, True) + if __name__ == '__main__': unittest.main() From 7b6b2bd1852f2f90eacbae89fb9ab312d764e3f9 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 6 Jan 2025 09:58:21 -0800 Subject: [PATCH 28/56] add support for max_sample_num in dataset configs; add tests --- configs/datasets/mixture.yaml | 18 +- configs/datasets/ondisk_json.yaml | 7 +- configs/datasets/ondisk_parquet.yaml | 7 +- configs/datasets/remote_arxiv.yaml | 13 +- configs/datasets/remote_commoncrawl.yaml | 15 +- configs/datasets/remote_huggingface.yaml | 13 +- configs/datasets/remote_modelscope.yaml | 13 +- configs/datasets/remote_wiki.yaml | 13 +- configs/datasets/validation.yaml | 5 +- data_juicer/core/data/dataset_builder.py | 52 +++-- data_juicer/core/data/load_strategy.py | 18 +- tests/core/data/test_config.yaml | 7 +- tests/core/data/test_config_list.yaml | 13 +- tests/core/test_dataset_builder.py | 242 ++++++++++++++++++----- 14 files changed, 307 insertions(+), 129 deletions(-) diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index b67a9b5e2..e28ae040c 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -1,10 +1,12 @@ project_name: 'dataset-mixture' dataset: - - type: 'local' - weight: 1.0 - path: - - 'path/to/json/file' - - type: 'local' - weight: 1.0 - files: - - 'path/to/csv/file' + max_sample_num: 10000 + configs: + - type: 'local' + weight: 1.0 + path: + - 'path/to/json/file' + - type: 'local' + weight: 1.0 + files: + - 'path/to/csv/file' diff --git a/configs/datasets/ondisk_json.yaml b/configs/datasets/ondisk_json.yaml index af3ac7049..19b0899d6 100644 --- a/configs/datasets/ondisk_json.yaml +++ b/configs/datasets/ondisk_json.yaml @@ -1,6 +1,7 @@ # global parameters project_name: 'dataset-ondisk-json' dataset: - - type: 'ondisk' - path: - - 'path/to/json/file' + configs: + - type: 'ondisk' + path: + - 'path/to/json/file' diff --git a/configs/datasets/ondisk_parquet.yaml b/configs/datasets/ondisk_parquet.yaml index 0e5256a20..08d0ae341 100644 --- a/configs/datasets/ondisk_parquet.yaml +++ b/configs/datasets/ondisk_parquet.yaml @@ -1,6 +1,7 @@ # global parameters project_name: 'dataset-ondisk-parquet' dataset: - - type: 'ondisk' - path: - - 'path/to/parquet/file' + configs: + - type: 'ondisk' + path: + - 'path/to/parquet/file' diff --git a/configs/datasets/remote_arxiv.yaml b/configs/datasets/remote_arxiv.yaml index 1d14596c2..febcdeeb4 100644 --- a/configs/datasets/remote_arxiv.yaml +++ b/configs/datasets/remote_arxiv.yaml @@ -1,9 +1,10 @@ # global parameters project_name: 'dataset-remote-arxiv' dataset: - - type: 'remote' - source: 'arxiv' - lang: 'en' - dump_date: 'latest' - force_download: false - url_limit: 2 + configs: + - type: 'remote' + source: 'arxiv' + lang: 'en' + dump_date: 'latest' + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_commoncrawl.yaml b/configs/datasets/remote_commoncrawl.yaml index ce9f7a049..8b57e077b 100644 --- a/configs/datasets/remote_commoncrawl.yaml +++ b/configs/datasets/remote_commoncrawl.yaml @@ -1,10 +1,11 @@ # global parameters project_name: 'dataset-remote-commoncrawl' dataset: - - type: 'remote' - source: 'commoncrawl' - start_snapshot: '2020-50' - end_snapshot: '2021-04' - aws: true - force_download: false - url_limit: 2 + configs: + - type: 'remote' + source: 'commoncrawl' + start_snapshot: '2020-50' + end_snapshot: '2021-04' + aws: true + force_download: false + url_limit: 2 diff --git a/configs/datasets/remote_huggingface.yaml b/configs/datasets/remote_huggingface.yaml index 13738a6c5..e2f0532f9 100644 --- a/configs/datasets/remote_huggingface.yaml +++ b/configs/datasets/remote_huggingface.yaml @@ -1,9 +1,10 @@ # global parameters project_name: 'dataset-remote-huggingface' dataset: - - type: 'remote' - source: 'huggingface' - path: "HuggingFaceFW/fineweb" - name: "CC-MAIN-2024-10" - split: "train" - limit: 1000 + configs: + - type: 'remote' + source: 'huggingface' + path: "HuggingFaceFW/fineweb" + name: "CC-MAIN-2024-10" + split: "train" + limit: 1000 diff --git a/configs/datasets/remote_modelscope.yaml b/configs/datasets/remote_modelscope.yaml index 88b76c461..62b3b2869 100644 --- a/configs/datasets/remote_modelscope.yaml +++ b/configs/datasets/remote_modelscope.yaml @@ -1,9 +1,10 @@ # global parameters project_name: 'dataset-remote-modelscope' dataset: - type: 'remote' - source: 'modelscope' - path: 'modelscope/clue' - subset_name: 'afqmc' - split: 'train' - limit: 1000 + configs: + - type: 'remote' + source: 'modelscope' + path: 'modelscope/clue' + subset_name: 'afqmc' + split: 'train' + limit: 1000 diff --git a/configs/datasets/remote_wiki.yaml b/configs/datasets/remote_wiki.yaml index 6e94c7549..a3eb4e0d0 100644 --- a/configs/datasets/remote_wiki.yaml +++ b/configs/datasets/remote_wiki.yaml @@ -1,9 +1,10 @@ # global parameters project_name: 'dataset-remote-wiki' dataset: - type: 'remote' - source: 'wiki' - lang: 'en' - dump_date: 'latest' - force_download: false - url_limit: 2 + configs: + - type: 'remote' + source: 'wiki' + lang: 'en' + dump_date: 'latest' + force_download: false + url_limit: 2 diff --git a/configs/datasets/validation.yaml b/configs/datasets/validation.yaml index dbe1cbedd..77947e48d 100644 --- a/configs/datasets/validation.yaml +++ b/configs/datasets/validation.yaml @@ -1,6 +1,7 @@ dataset: - type: ondisk - path: path/to/data.json + configs: + - type: ondisk + path: path/to/data.json validators: - type: conversation diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 2e2ec1e84..7b4e678c7 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -26,16 +26,19 @@ def __init__(self, cfg, executor_type): 'Unable to initialize dataset; should have one of ' 'dataset_path or dataset in configurations') - # dataset config could be a list or a single entry; retrofit - if not isinstance(ds_configs, list): - ds_configs = [ds_configs] - # validate dataset config for type constraints - # 1. ds_config should be a dictionary - # 2. ds_configs should only have one type - # 3. if type is REMOTE, there should only one ds_config + # 1. ds_config should have a 'configs' key + # 2. ds_config['configs'] should be a list + # 3. ds_configs should only have one type + # 4. if type is REMOTE, there should only one ds_config # TODO other constraints; ray dataset only supports ondisk, etc. - for ds_config in ds_configs: + if 'configs' not in ds_configs: + raise ConfigValidationError( + 'Dataset config should have a "configs" key') + if not isinstance(ds_configs['configs'], list): + raise ConfigValidationError( + 'Dataset config "configs" should be a list') + for ds_config in ds_configs['configs']: if type(ds_config) != dict: raise ConfigValidationError( 'Dataset config should be a dictionary') @@ -70,16 +73,19 @@ def __init__(self, cfg, executor_type): def load_dataset(self) -> Union[NestedDataset, RayDataset]: _datasets = [] - for f in self.load_strategies: + for stra in self.load_strategies: # load dataset with its load strategy - _dataset = f.load_data(self.cfg) + dataset = stra.load_data(self.cfg) + sampled = self.random_sample(dataset, stra.weight) + + # deal with sampling + if stra.sampling_strategy: + None # do data validation for validator in self.validators: - validator.validate(_dataset) - _datasets.append(_dataset) - - # handle data mixture; only supports ONDISK + validator.validate(sampled) + _datasets.append(sampled) return _datasets[0] @@ -99,7 +105,7 @@ def load_dataset_by_generated_config(cls, generated_dataset_config): return dataset -def rewrite_cli_datapath(dataset_path) -> List: +def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> List: """ rewrite the dataset_path from CLI into proper dataset config format that is compatible with YAML config style; retrofitting CLI input @@ -107,18 +113,28 @@ def rewrite_cli_datapath(dataset_path) -> List: :param dataset_path: a dataset file or a dataset dir or a list of them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` + :param max_sample_num: the maximum number of samples to load :return: list of dataset configs """ paths, weights = parse_cli_datapath(dataset_path) - ret = [] + ret = ({ + 'configs': [], + 'max_sample_num': max_sample_num + } if max_sample_num else { + 'configs': [] + }) for p, w in zip(paths, weights): if os.path.isdir(p) or os.path.isfile(p): # local files - ret.append({'type': 'ondisk', 'path': [p], 'weight': w}) + ret['configs'].append({'type': 'ondisk', 'path': [p], 'weight': w}) elif (not is_absolute_path(p) and not p.startswith('.') and p.count('/') <= 1): # remote huggingface - ret.append({'type': 'huggingface', 'path': p, 'split': 'train'}) + ret['configs'].append({ + 'type': 'huggingface', + 'path': p, + 'split': 'train' + }) else: # raise ValueError( diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 0b0377739..723f205b8 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -49,6 +49,7 @@ class DataLoadStrategy(ABC, ConfigValidator): def __init__(self, ds_config: Dict): self.validate_config(ds_config) self.ds_config = ds_config + self.weight = ds_config.get('weight', 1.0) # default weight is 1.0 @abstractmethod def load_data(self, cfg: Namespace) -> Union[DJDataset, RayDataset]: @@ -197,7 +198,8 @@ class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): } def load_data(self, cfg: Namespace): - pass + raise NotImplementedError( + 'Huggingface data load strategy is not implemented') @DataLoadStrategyRegistry.register('local', 'ondisk', '*') @@ -234,7 +236,8 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, cfg: Namespace): - pass + raise NotImplementedError( + 'Huggingface data load strategy is not implemented') @DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') @@ -244,7 +247,8 @@ class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): """ def load_data(self, cfg: Namespace): - pass + raise NotImplementedError( + 'ModelScope data load strategy is not implemented') @DataLoadStrategyRegistry.register('local', 'remote', 'arxiv') @@ -262,7 +266,8 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, cfg: Namespace): - pass + raise NotImplementedError( + 'Arxiv data load strategy is not implemented') @DataLoadStrategyRegistry.register('local', 'remote', 'wiki') @@ -280,7 +285,7 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, cfg: Namespace): - pass + raise NotImplementedError('Wiki data load strategy is not implemented') @DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl') @@ -304,4 +309,5 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, cfg: Namespace): - pass + raise NotImplementedError( + 'CommonCrawl data load strategy is not implemented') diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml index 9620bed65..fe10ab1be 100644 --- a/tests/core/data/test_config.yaml +++ b/tests/core/data/test_config.yaml @@ -1,5 +1,6 @@ project_name: 'dataset-ondisk-json' dataset: - type: 'ondisk' - path: - - 'sample.json' \ No newline at end of file + configs: + - type: 'ondisk' + path: + - 'sample.json' \ No newline at end of file diff --git a/tests/core/data/test_config_list.yaml b/tests/core/data/test_config_list.yaml index 61ad61162..ed8eeef47 100644 --- a/tests/core/data/test_config_list.yaml +++ b/tests/core/data/test_config_list.yaml @@ -1,8 +1,9 @@ project_name: 'dataset-ondisk-list' dataset: - - type: 'ondisk' - path: - - 'sample.json' - - type: 'ondisk' - path: - - 'sample.txt' \ No newline at end of file + configs: + - type: 'ondisk' + path: + - 'sample.json' + - type: 'ondisk' + path: + - 'sample.txt' \ No newline at end of file diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index a7d56734c..e46ca37ce 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -31,21 +31,27 @@ def test_rewrite_cli_datapath_local_single_file(self): dataset_path = "./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans) + {'configs': [ + {'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}]}, + ans) def test_rewrite_cli_datapath_local_directory(self): dataset_path = "./data" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans) + {'configs': [ + {'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}]}, + ans) def test_rewrite_cli_datapath_hf(self): dataset_path = "hf-internal-testing/librispeech_asr_dummy" ans = rewrite_cli_datapath(dataset_path) - self.assertEqual([{'path': 'hf-internal-testing/librispeech_asr_dummy', - 'split': 'train', - 'type': 'huggingface'}], - ans) + self.assertEqual( + {'configs': [ + {'path': 'hf-internal-testing/librispeech_asr_dummy', + 'split': 'train', + 'type': 'huggingface'}]}, + ans) def test_rewrite_cli_datapath_local_wrong_files(self): dataset_path = "./missingDir" @@ -56,8 +62,9 @@ def test_rewrite_cli_datapath_with_weights(self): dataset_path = "0.5 ./data/sample.json ./data/sample.txt" ans = rewrite_cli_datapath(dataset_path) self.assertEqual( - [{'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5}, - {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], + {'configs': [ + {'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5}, + {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}]}, ans) @patch('os.path.isdir') @@ -68,19 +75,23 @@ def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir): mock_isdir.side_effect = lambda x: x.endswith('_dir') dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" - expected = [ - {'type': 'ondisk', 'path': ['ds1.jsonl'], 'weight': 1.0}, - {'type': 'ondisk', 'path': ['ds2_dir'], 'weight': 2.0}, - {'type': 'ondisk', 'path': ['ds3.jsonl'], 'weight': 3.0} - ] + expected = { + 'configs': [ + {'type': 'ondisk', 'path': ['ds1.jsonl'], 'weight': 1.0}, + {'type': 'ondisk', 'path': ['ds2_dir'], 'weight': 2.0}, + {'type': 'ondisk', 'path': ['ds3.jsonl'], 'weight': 3.0} + ] + } result = rewrite_cli_datapath(dataset_path) self.assertEqual(result, expected) def test_rewrite_cli_datapath_huggingface(self): dataset_path = "1.0 huggingface/dataset" - expected = [ - {'type': 'huggingface', 'path': 'huggingface/dataset', 'split': 'train'} - ] + expected = { + 'configs': [ + {'type': 'huggingface', 'path': 'huggingface/dataset', 'split': 'train'} + ] + } result = rewrite_cli_datapath(dataset_path) self.assertEqual(result, expected) @@ -89,6 +100,38 @@ def test_rewrite_cli_datapath_invalid(self): with self.assertRaises(ValueError): rewrite_cli_datapath(dataset_path) + def test_rewrite_cli_datapath_with_max_samples(self): + """Test rewriting CLI datapath with max_sample_num""" + dataset_path = "./data/sample.txt" + max_sample_num = 1000 + + result = rewrite_cli_datapath(dataset_path, max_sample_num) + + expected = { + 'configs': [{ + 'type': 'ondisk', + 'path': ['./data/sample.txt'], + 'weight': 1.0 + }], + 'max_sample_num': 1000 + } + self.assertEqual(result, expected) + + def test_rewrite_cli_datapath_without_max_samples(self): + """Test rewriting CLI datapath without max_sample_num""" + dataset_path = "./data/sample.txt" + + result = rewrite_cli_datapath(dataset_path) + + expected = { + 'configs': [{ + 'type': 'ondisk', + 'path': ['./data/sample.txt'], + 'weight': 1.0 + }] + } + self.assertEqual(result, expected) + def test_parse_cli_datapath(self): dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" expected_paths = ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl'] @@ -206,8 +249,12 @@ def test_builder_single_dataset_config(self): """Test handling of single dataset configuration""" # Setup single dataset config self.base_cfg.dataset = { - 'type': 'ondisk', - 'path': 'test.jsonl' + 'configs': [ + { + 'type': 'ondisk', + 'path': 'test.jsonl' + } + ] } builder = DatasetBuilder(self.base_cfg, self.executor_type) @@ -224,16 +271,18 @@ def test_builder_single_dataset_config(self): def test_builder_multiple_dataset_config(self): """Test handling of multiple dataset configurations""" # Setup multiple dataset config - self.base_cfg.dataset = [ - { - 'type': 'ondisk', - 'path': 'test1.jsonl' - }, - { - 'type': 'ondisk', - 'path': 'test2.jsonl' - } - ] + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'ondisk', + 'path': 'test1.jsonl' + }, + { + 'type': 'ondisk', + 'path': 'test2.jsonl' + } + ] + } builder = DatasetBuilder(self.base_cfg, self.executor_type) @@ -255,16 +304,18 @@ def test_builder_none_dataset_config(self): def test_builder_mixed_dataset_types(self): """Test validation of mixed dataset types""" - self.base_cfg.dataset = [ - { - 'type': 'ondisk', - 'path': 'test1.jsonl' - }, - { - 'type': 'remote', - 'source': 'some_source' - } - ] + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'ondisk', + 'path': 'test1.jsonl' + }, + { + 'type': 'remote', + 'source': 'some_source' + } + ] + } with self.assertRaises(ConfigValidationError) as context: DatasetBuilder(self.base_cfg, self.executor_type) @@ -273,16 +324,18 @@ def test_builder_mixed_dataset_types(self): def test_builder_multiple_remote_datasets(self): """Test validation of multiple remote datasets""" - self.base_cfg.dataset = [ - { - 'type': 'remote', - 'source': 'source1' + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'remote', + 'source': 'source1' }, { 'type': 'remote', 'source': 'source2' - } - ] + } + ] + } with self.assertRaises(ConfigValidationError) as context: DatasetBuilder(self.base_cfg, self.executor_type) @@ -291,7 +344,9 @@ def test_builder_multiple_remote_datasets(self): def test_builder_empty_dataset_config(self): """Test handling of empty dataset configuration""" - self.base_cfg.dataset = [] + self.base_cfg.dataset = { + 'configs': [] + } with self.assertRaises(ConfigValidationError) as context: DatasetBuilder(self.base_cfg, self.executor_type) @@ -316,7 +371,7 @@ def test_builder_ondisk_config(self): self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'dataset-ondisk-json') self.assertEqual(cfg.dataset, - {'path': ['sample.json'], 'type': 'ondisk'}) + {'configs': [{'path': ['sample.json'], 'type': 'ondisk'}]}) self.assertEqual(not cfg.dataset_path, True) def test_builder_ondisk_config_list(self): @@ -326,10 +381,99 @@ def test_builder_ondisk_config_list(self): cfg = init_configs(args=f'--config {test_config_file}'.split()) self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'dataset-ondisk-list') - self.assertEqual(cfg.dataset,[ - {'path': ['sample.json'], 'type': 'ondisk'}, - {'path': ['sample.txt'], 'type': 'ondisk'}]) + self.assertEqual(cfg.dataset, + {'configs': [ + {'path': ['sample.json'], 'type': 'ondisk'}, + {'path': ['sample.txt'], 'type': 'ondisk'} + ]}) self.assertEqual(not cfg.dataset_path, True) + def test_builder_with_max_samples(self): + """Test DatasetBuilder with max_sample_num""" + self.base_cfg.dataset = { + 'configs': [{ + 'type': 'ondisk', + 'path': ['test.jsonl'], + 'weight': 1.0 + }], + 'max_sample_num': 1000 + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertEqual(len(builder.load_strategies), 1) + self.assertEqual(builder.max_sample_num, 1000) + + def test_builder_without_max_samples(self): + """Test DatasetBuilder without max_sample_num""" + self.base_cfg.dataset = { + 'configs': [{ + 'type': 'ondisk', + 'path': ['test.jsonl'], + 'weight': 1.0 + }] + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertEqual(len(builder.load_strategies), 1) + self.assertIsNone(builder.max_sample_num) + + def test_mixed_dataset_configs(self): + """Test handling of mixed dataset configurations""" + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'ondisk', + 'path': ['test1.jsonl'], + 'weight': 1.0 + }, + { + 'type': 'ondisk', + 'path': ['test2.jsonl'], + 'weight': 2.0 + } + ], + 'max_sample_num': 500 + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + + self.assertEqual(len(builder.load_strategies), 2) + self.assertEqual(builder.max_sample_num, 500) + self.assertEqual( + builder.load_strategies[0].ds_config['weight'], + 1.0 + ) + self.assertEqual( + builder.load_strategies[1].ds_config['weight'], + 2.0 + ) + + def test_invalid_max_sample_num(self): + """Test handling of invalid max_sample_num""" + invalid_values = [-1, 0, "100", None] + + for value in invalid_values: + self.base_cfg.dataset = { + 'configs': [{ + 'type': 'ondisk', + 'path': ['test.jsonl'], + 'weight': 1.0 + }], + 'max_sample_num': value + } + + if value is not None and value <= 0: + with self.assertRaises(ConfigValidationError): + DatasetBuilder(self.base_cfg, self.executor_type) + elif value == "100": + with self.assertRaises(ConfigValidationError): + DatasetBuilder(self.base_cfg, self.executor_type) + else: + builder = DatasetBuilder(self.base_cfg, self.executor_type) + if value is None: + self.assertIsNone(builder.max_sample_num) + if __name__ == '__main__': unittest.main() From 161f0598c5529f59a8f81b2a95ddec5951ee3f94 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 6 Jan 2025 10:32:59 -0800 Subject: [PATCH 29/56] fix test cases and update dataset builder code --- data_juicer/core/data/dataset_builder.py | 34 ++++++++++++++++-------- tests/core/test_dataset_builder.py | 24 +++++++---------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 7b4e678c7..a26af68e4 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -11,6 +11,9 @@ class DatasetBuilder(object): + """ + DatasetBuilder is a class that builds a dataset from a configuration. + """ def __init__(self, cfg, executor_type): self.cfg = cfg @@ -19,7 +22,7 @@ def __init__(self, cfg, executor_type): # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) - elif cfg.dataset not in (None, []): + elif cfg.dataset is not None: ds_configs = cfg.dataset else: raise ConfigValidationError( @@ -27,32 +30,41 @@ def __init__(self, cfg, executor_type): 'dataset_path or dataset in configurations') # validate dataset config for type constraints - # 1. ds_config should have a 'configs' key - # 2. ds_config['configs'] should be a list - # 3. ds_configs should only have one type - # 4. if type is REMOTE, there should only one ds_config # TODO other constraints; ray dataset only supports ondisk, etc. + if type(ds_configs) != dict: + raise ConfigValidationError( + 'Dataset config should be a dictionary') if 'configs' not in ds_configs: raise ConfigValidationError( 'Dataset config should have a "configs" key') - if not isinstance(ds_configs['configs'], list): + if (not isinstance(ds_configs['configs'], list) + or len(ds_configs['configs']) == 0): raise ConfigValidationError( - 'Dataset config "configs" should be a list') + 'Dataset config "configs" should be a non-empty list') + if ('max_sample_num' in ds_configs + and (type(ds_configs['max_sample_num']) != int + or ds_configs['max_sample_num'] <= 0)): + raise ConfigValidationError( + 'Dataset config "max_sample_num" should be a positive integer') for ds_config in ds_configs['configs']: if type(ds_config) != dict: raise ConfigValidationError( - 'Dataset config should be a dictionary') - types = [ds_config.get('type', None) for ds_config in ds_configs] + 'Dataset configs should be dictionaries') + types = [ + ds_config.get('type', None) for ds_config in ds_configs['configs'] + ] if len(set(types)) > 1: raise ConfigValidationError( 'Mixture of diff types (ONDISK/REMOTE/...) are not supported') - if types[0] == 'remote' and len(ds_configs) > 1: + if types[0] == 'remote' and len(ds_configs['configs']) > 1: raise ConfigValidationError( 'Multiple remote datasets are not supported') + self.max_sample_num = ds_configs.get('max_sample_num', None) + # initialize the data load strategies self.load_strategies = [] - for ds_config in ds_configs: + for ds_config in ds_configs['configs']: # initialize data loading strategy data_type = ds_config.get('type', None) data_source = ds_config.get('source', None) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index e46ca37ce..cde927c52 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -329,10 +329,10 @@ def test_builder_multiple_remote_datasets(self): { 'type': 'remote', 'source': 'source1' - }, - { - 'type': 'remote', - 'source': 'source2' + }, + { + 'type': 'remote', + 'source': 'source2' } ] } @@ -351,7 +351,7 @@ def test_builder_empty_dataset_config(self): with self.assertRaises(ConfigValidationError) as context: DatasetBuilder(self.base_cfg, self.executor_type) - self.assertIn('dataset_path or dataset', str(context.exception)) + self.assertIn('non-empty list', str(context.exception)) def test_builder_invalid_dataset_config_type(self): """Test handling of invalid dataset configuration type""" @@ -464,16 +464,10 @@ def test_invalid_max_sample_num(self): 'max_sample_num': value } - if value is not None and value <= 0: - with self.assertRaises(ConfigValidationError): - DatasetBuilder(self.base_cfg, self.executor_type) - elif value == "100": - with self.assertRaises(ConfigValidationError): - DatasetBuilder(self.base_cfg, self.executor_type) - else: - builder = DatasetBuilder(self.base_cfg, self.executor_type) - if value is None: - self.assertIsNone(builder.max_sample_num) + with self.assertRaises(ConfigValidationError) as context: + DatasetBuilder(self.base_cfg, self.executor_type) + self.assertIn('should be a positive integer', + str(context.exception)) if __name__ == '__main__': unittest.main() From afe906d100537ca157339db2d8864201d13fba38 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 8 Jan 2025 14:09:46 -0800 Subject: [PATCH 30/56] handle weights and sample_nums --- configs/datasets/mixture.yaml | 6 +- configs/datasets/ondisk_json.yaml | 3 +- configs/datasets/ondisk_parquet.yaml | 3 +- data_juicer/core/data/dataset_builder.py | 67 ++++++++++++++++----- data_juicer/core/data/load_strategy.py | 59 +++++++++++------- data_juicer/core/executor/local_executor.py | 2 +- 6 files changed, 96 insertions(+), 44 deletions(-) diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index e28ae040c..9fe18cccb 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -4,9 +4,7 @@ dataset: configs: - type: 'local' weight: 1.0 - path: - - 'path/to/json/file' + path: 'path/to/json/file' - type: 'local' weight: 1.0 - files: - - 'path/to/csv/file' + files: 'path/to/csv/file' diff --git a/configs/datasets/ondisk_json.yaml b/configs/datasets/ondisk_json.yaml index 19b0899d6..a01e3b5a1 100644 --- a/configs/datasets/ondisk_json.yaml +++ b/configs/datasets/ondisk_json.yaml @@ -3,5 +3,4 @@ project_name: 'dataset-ondisk-json' dataset: configs: - type: 'ondisk' - path: - - 'path/to/json/file' + path: 'path/to/json/file' diff --git a/configs/datasets/ondisk_parquet.yaml b/configs/datasets/ondisk_parquet.yaml index 08d0ae341..e0f2fb144 100644 --- a/configs/datasets/ondisk_parquet.yaml +++ b/configs/datasets/ondisk_parquet.yaml @@ -3,5 +3,4 @@ project_name: 'dataset-ondisk-parquet' dataset: configs: - type: 'ondisk' - path: - - 'path/to/parquet/file' + path: 'path/to/parquet/file' diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index a26af68e4..4545f69cc 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -2,12 +2,16 @@ import shlex from typing import List, Tuple, Union +import numpy as np +from datasets import concatenate_datasets + from data_juicer.core.data import NestedDataset from data_juicer.core.data.config_validator import ConfigValidationError from data_juicer.core.data.data_validator import DataValidatorRegistry from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset from data_juicer.utils.file_utils import is_absolute_path +from data_juicer.utils.sample import random_sample class DatasetBuilder(object): @@ -60,8 +64,6 @@ def __init__(self, cfg, executor_type): raise ConfigValidationError( 'Multiple remote datasets are not supported') - self.max_sample_num = ds_configs.get('max_sample_num', None) - # initialize the data load strategies self.load_strategies = [] for ds_config in ds_configs['configs']: @@ -70,7 +72,16 @@ def __init__(self, cfg, executor_type): data_source = ds_config.get('source', None) self.load_strategies.append( DataLoadStrategyRegistry.get_strategy_class( - self.executor_type, data_type, data_source)(ds_config)) + self.executor_type, data_type, data_source)(ds_config, + cfg=self.cfg)) + + # initialzie the sample numbers + self.max_sample_num = ds_configs.get('max_sample_num', None) + # get weights and sample numbers + if self.max_sample_num is not None: + self.weights = [stra.weight for stra in self.load_strategies] + self.sample_numbers = get_sample_numbers(self.weights, + self.max_sample_num) # initialize data validators self.validators = [] @@ -82,24 +93,30 @@ def __init__(self, cfg, executor_type): if validator_cls: self.validators.append(validator_cls(validator_config)) - def load_dataset(self) -> Union[NestedDataset, RayDataset]: + def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]: _datasets = [] - for stra in self.load_strategies: + # load datasets with sample numbers + for stra, weight, sample_num in zip(self.load_strategies, self.weights, + self.sample_numbers): # load dataset with its load strategy - dataset = stra.load_data(self.cfg) - sampled = self.random_sample(dataset, stra.weight) - - # deal with sampling - if stra.sampling_strategy: - None + dataset = stra.load_data(**kwargs) # do data validation for validator in self.validators: - validator.validate(sampled) - _datasets.append(sampled) + validator.validate(dataset) + + # do data sampling, if necessary + if self.max_sample_num is not None: + dataset = random_sample(dataset, weight, sample_num) - return _datasets[0] + _datasets.append(dataset) + + # handle data mixture + if self.executor_type == 'local': + return NestedDataset(concatenate_datasets(_datasets)) + elif self.executor_type == 'ray': + return RayDataset(_datasets[0], ) @classmethod def load_dataset_by_generated_config(cls, generated_dataset_config): @@ -189,3 +206,25 @@ def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]: prefixes.append(value) return prefixes, weights + + +def get_sample_numbers(weights, max_sample_num): + sample_numbers = [0] * len(weights) + + # Normalize weights + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + sample_num_per_dataset = [ + int(np.ceil(max_sample_num * weight)) for weight in weights + ] + + # Adjust + acc_sample_numbers = 0 + for i in range(len(sample_num_per_dataset)): + sample_numbers[i] = min(sample_num_per_dataset[i], + max_sample_num - acc_sample_numbers) + acc_sample_numbers += sample_numbers[i] + + return sample_numbers diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 723f205b8..ae094ca51 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -4,9 +4,13 @@ from dataclasses import dataclass from typing import Dict, Optional, Type, Union +import datasets + from data_juicer.core.data import DJDataset, RayDataset from data_juicer.core.data.config_validator import ConfigValidator from data_juicer.download.downloader import validate_snapshot_format +from data_juicer.format.formatter import unify_format +from data_juicer.format.load import load_formatter from data_juicer.utils.lazy_loader import LazyLoader ray = LazyLoader('ray', 'ray') @@ -46,13 +50,14 @@ class DataLoadStrategy(ABC, ConfigValidator): abstract class for data load strategy """ - def __init__(self, ds_config: Dict): + def __init__(self, ds_config: Dict, cfg: Namespace): self.validate_config(ds_config) self.ds_config = ds_config + self.cfg = cfg self.weight = ds_config.get('weight', 1.0) # default weight is 1.0 @abstractmethod - def load_data(self, cfg: Namespace) -> Union[DJDataset, RayDataset]: + def load_data(self, **kwargs) -> Union[DJDataset, RayDataset]: pass @@ -144,7 +149,7 @@ class RayDataLoadStrategy(DataLoadStrategy): """ @abstractmethod - def load_data(self) -> RayDataset: + def load_data(self, **kwargs) -> RayDataset: pass @@ -154,7 +159,7 @@ class LocalDataLoadStrategy(DataLoadStrategy): """ @abstractmethod - def load_data(self, cfg: Namespace) -> DJDataset: + def load_data(self, **kwargs) -> DJDataset: pass @@ -177,12 +182,12 @@ class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': str }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): + def load_data(self, **kwargs): return rd.read_json(self.ds_config.path) @@ -192,12 +197,12 @@ class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': str }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): + def load_data(self, **kwargs): raise NotImplementedError( 'Huggingface data load strategy is not implemented') @@ -212,13 +217,18 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': str }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): - pass + def load_data(self, **kwargs): + # use proper formatter to load data; kwargs are ignored + formatter = load_formatter(dataset_path=self.ds_config.path, + suffixes=self.cfg.suffixes, + text_keys=self.cfg.text_keys, + add_suffix=self.cfg.add_suffix) + return formatter.load_data() @DataLoadStrategyRegistry.register('local', 'remote', 'huggingface') @@ -229,15 +239,22 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], + 'optional_fields': ['split', 'limit', 'name'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': str }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): - raise NotImplementedError( - 'Huggingface data load strategy is not implemented') + def load_data(self, **kwargs): + num_proc = kwargs.get('num_proc', 1) + ds = datasets.load_dataset(self.ds_config.path, + split=self.ds_config.split, + name=self.ds_config.name, + limit=self.ds_config.limit, + num_proc=num_proc, + **kwargs) + ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc) @DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') @@ -246,7 +263,7 @@ class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): data load strategy for ModelScope dataset for LocalExecutor """ - def load_data(self, cfg: Namespace): + def load_data(self): raise NotImplementedError( 'ModelScope data load strategy is not implemented') @@ -260,12 +277,12 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': (str) # has to be a string }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): + def load_data(self): raise NotImplementedError( 'Arxiv data load strategy is not implemented') @@ -279,12 +296,12 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { - 'path': (str, list) # Can be string or list + 'path': str }, 'custom_validators': {} } - def load_data(self, cfg: Namespace): + def load_data(self): raise NotImplementedError('Wiki data load strategy is not implemented') @@ -308,6 +325,6 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): } } - def load_data(self, cfg: Namespace): + def load_data(self): raise NotImplementedError( 'CommonCrawl data load strategy is not implemented') diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index d3ddaf736..5fba09169 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -113,7 +113,7 @@ def run(self, logger.info('Loading dataset from dataset builder...') if load_data_np is None: load_data_np = self.cfg.np - dataset = self.dataset_builder.load_dataset() + dataset = self.dataset_builder.load_dataset(num_proc=load_data_np) # 2. extract processes and optimize their orders logger.info('Preparing process operators...') From 1217e618467a626bebcadc0f7cbe57ee3a43a808 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 9 Jan 2025 11:56:37 -0800 Subject: [PATCH 31/56] support ExecutorType enum --- data_juicer/core/data/data_validator.py | 3 +- data_juicer/core/data/dataset_builder.py | 38 +++++++++---- data_juicer/core/data/load_strategy.py | 39 +++++++------ data_juicer/core/executor/base.py | 7 +++ data_juicer/format/load.py | 50 +--------------- tests/core/test_dataload_strategy.py | 72 +++++++++++++++--------- 6 files changed, 104 insertions(+), 105 deletions(-) diff --git a/data_juicer/core/data/data_validator.py b/data_juicer/core/data/data_validator.py index 19a1f4d64..5f0d9346d 100644 --- a/data_juicer/core/data/data_validator.py +++ b/data_juicer/core/data/data_validator.py @@ -157,7 +157,8 @@ def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: sample = dataset.data.take(sample_size) values = [row[field] for row in sample] else: - raise NotImplementedError + raise NotImplementedError( + f'Unsupported dataset type: {type(dataset)}') # Check for missing values missing_count = sum(1 for v in values if v is None) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 4545f69cc..92318e29c 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -1,5 +1,6 @@ import os import shlex +from argparse import Namespace from typing import List, Tuple, Union import numpy as np @@ -10,6 +11,7 @@ from data_juicer.core.data.data_validator import DataValidatorRegistry from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset +from data_juicer.core.executor.base import ExecutorType from data_juicer.utils.file_utils import is_absolute_path from data_juicer.utils.sample import random_sample @@ -19,11 +21,16 @@ class DatasetBuilder(object): DatasetBuilder is a class that builds a dataset from a configuration. """ - def __init__(self, cfg, executor_type): + def __init__(self, cfg: Namespace, executor_type: str): + # if generated_dataset_config present, prioritize + if cfg.generated_dataset_config: + self.use_generated_dataset_config = True + self.generated_dataset_config = cfg.generated_dataset_config + return + self.cfg = cfg - self.executor_type = executor_type + self.executor_type = ExecutorType(executor_type) - # defaults to use dataset_path if cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) elif cfg.dataset is not None: @@ -31,7 +38,8 @@ def __init__(self, cfg, executor_type): else: raise ConfigValidationError( 'Unable to initialize dataset; should have one of ' - 'dataset_path or dataset in configurations') + 'generated_dataset_configdataset_path or dataset ' + 'in configurations') # validate dataset config for type constraints # TODO other constraints; ray dataset only supports ondisk, etc. @@ -72,13 +80,13 @@ def __init__(self, cfg, executor_type): data_source = ds_config.get('source', None) self.load_strategies.append( DataLoadStrategyRegistry.get_strategy_class( - self.executor_type, data_type, data_source)(ds_config, - cfg=self.cfg)) + self.executor_type.value, data_type, + data_source)(ds_config, cfg=self.cfg)) # initialzie the sample numbers self.max_sample_num = ds_configs.get('max_sample_num', None) # get weights and sample numbers - if self.max_sample_num is not None: + if self.max_sample_num: self.weights = [stra.weight for stra in self.load_strategies] self.sample_numbers = get_sample_numbers(self.weights, self.max_sample_num) @@ -94,8 +102,12 @@ def __init__(self, cfg, executor_type): self.validators.append(validator_cls(validator_config)) def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]: - _datasets = [] + # if generated_dataset_config present, prioritize + if self.use_generated_dataset_config: + return DatasetBuilder.load_dataset_by_generated_config( + self.generated_dataset_config) + _datasets = [] # load datasets with sample numbers for stra, weight, sample_num in zip(self.load_strategies, self.weights, self.sample_numbers): @@ -107,16 +119,18 @@ def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]: validator.validate(dataset) # do data sampling, if necessary - if self.max_sample_num is not None: + if self.max_sample_num: dataset = random_sample(dataset, weight, sample_num) _datasets.append(dataset) # handle data mixture - if self.executor_type == 'local': + if self.executor_type == ExecutorType.LOCAL: return NestedDataset(concatenate_datasets(_datasets)) - elif self.executor_type == 'ray': - return RayDataset(_datasets[0], ) + elif self.executor_type == ExecutorType.RAY: + # TODO: support multiple datasets and mixing for ray + assert len(_datasets) == 1, 'Ray setup supports one dataset now' + return _datasets[0] @classmethod def load_dataset_by_generated_config(cls, generated_dataset_config): diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index ae094ca51..76c597d5c 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -8,6 +8,7 @@ from data_juicer.core.data import DJDataset, RayDataset from data_juicer.core.data.config_validator import ConfigValidator +from data_juicer.core.executor.base import ExecutorType from data_juicer.download.downloader import validate_snapshot_format from data_juicer.format.formatter import unify_format from data_juicer.format.load import load_formatter @@ -26,7 +27,7 @@ class StrategyKey: """ Immutable key for strategy registration with wildcard support """ - executor_type: str + executor_type: ExecutorType data_type: str data_source: str @@ -40,7 +41,8 @@ def matches(self, other: 'StrategyKey') -> bool: - '[seq]' matches any character in seq - '[!seq]' matches any character not in seq """ - return (fnmatch.fnmatch(other.executor_type, self.executor_type) + return (fnmatch.fnmatch(other.executor_type.value, + self.executor_type.value) and fnmatch.fnmatch(other.data_type, self.data_type) and fnmatch.fnmatch(other.data_source, self.data_source)) @@ -69,7 +71,7 @@ class DataLoadStrategyRegistry: @classmethod def get_strategy_class( - cls, executor_type: str, data_type: str, + cls, executor_type: ExecutorType, data_type: str, data_source: str) -> Optional[Type[DataLoadStrategy]]: """ Retrieve the most specific matching strategy @@ -79,7 +81,7 @@ def get_strategy_class( 2. Wildcard matches from most specific to most general """ # default to wildcard if not provided - executor_type = executor_type or '*' + executor_type = executor_type or ExecutorType.ANY data_type = data_type or '*' data_source = data_source or '*' @@ -119,7 +121,8 @@ def specificity_score(key: StrategyKey) -> int: return None @classmethod - def register(cls, executor_type: str, data_type: str, data_source: str): + def register(cls, executor_type: ExecutorType, data_type: str, + data_source: str): """ Decorator for registering data load strategies with wildcard support @@ -176,7 +179,7 @@ def load_data(self, **kwargs) -> DJDataset: # pass -@DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') +@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'ondisk', 'json') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { @@ -188,10 +191,13 @@ class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): } def load_data(self, **kwargs): - return rd.read_json(self.ds_config.path) + dataset = rd.read_json(self.ds_config.path) + return RayDataset(dataset, + dataset_path=self.ds_config.path, + cfg=self.cfg) -@DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface') +@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'remote', 'huggingface') class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { @@ -207,7 +213,7 @@ def load_data(self, **kwargs): 'Huggingface data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'ondisk', '*') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for on disk data for LocalExecutor @@ -223,15 +229,16 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, **kwargs): - # use proper formatter to load data; kwargs are ignored + # use proper formatter to load data formatter = load_formatter(dataset_path=self.ds_config.path, suffixes=self.cfg.suffixes, text_keys=self.cfg.text_keys, - add_suffix=self.cfg.add_suffix) + add_suffix=self.cfg.add_suffix**kwargs) + # TODO more sophiscated localformatter routing return formatter.load_data() -@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'huggingface') class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for Huggingface dataset for LocalExecutor @@ -257,7 +264,7 @@ def load_data(self, **kwargs): ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc) -@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'modelscope') class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for ModelScope dataset for LocalExecutor @@ -268,7 +275,7 @@ def load_data(self): 'ModelScope data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'arxiv') class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for arxiv dataset for LocalExecutor @@ -287,7 +294,7 @@ def load_data(self): 'Arxiv data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'wiki') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'wiki') class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for wiki dataset for LocalExecutor @@ -305,7 +312,7 @@ def load_data(self): raise NotImplementedError('Wiki data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl') +@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'commoncrawl') class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for commoncrawl dataset for LocalExecutor diff --git a/data_juicer/core/executor/base.py b/data_juicer/core/executor/base.py index e1d6e2073..124f5a7b9 100644 --- a/data_juicer/core/executor/base.py +++ b/data_juicer/core/executor/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Optional from jsonargparse import Namespace @@ -7,6 +8,12 @@ from data_juicer.config import init_configs +class ExecutorType(Enum): + LOCAL = 'local' + RAY = 'ray' + ANY = '*' + + class ExecutorBase(ABC): @abstractmethod diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py index 4cfbc8939..5bfcf1804 100644 --- a/data_juicer/format/load.py +++ b/data_juicer/format/load.py @@ -1,56 +1,15 @@ import os -from data_juicer.format import MixtureFormatter, RemoteFormatter from data_juicer.format.formatter import FORMATTERS, BaseFormatter -from data_juicer.utils.file_utils import (find_files_with_suffix, - is_absolute_path) +from data_juicer.utils.file_utils import find_files_with_suffix def load_formatter(dataset_path, - generated_dataset_config=None, text_keys=None, - suffixes=[], + suffixes=None, add_suffix=False, **kwargs) -> BaseFormatter: """ - Load mixture formatter for multiple different data formats with an optional - weight(default 1.0) according to their formats. - - :param dataset_path: path to a dataset file or a dataset directory - :param generated_dataset_config: Configuration used to create a dataset. - The dataset will be created from this configuration if provided. - It must contain the `type` field to specify the dataset name. - :param text_keys: key names of field that stores sample text. - Default: None - :param suffixes: files with specified suffixes to be processed. - :param add_suffix: whether to add the file suffix to dataset meta - info - :return: a dataset formatter. - """ - if generated_dataset_config: - assert isinstance(generated_dataset_config, - dict) and 'type' in generated_dataset_config - args = generated_dataset_config.copy() - obj_name = args.pop('type') - args.update(kwargs) - - from .formatter import FORMATTERS - return FORMATTERS.modules[obj_name](**args) - - formatter = MixtureFormatter(dataset_path=dataset_path, - text_keys=text_keys, - suffixes=suffixes, - add_suffix=add_suffix, - **kwargs) - return formatter - - -def _load_formatter(dataset_path, - text_keys=None, - suffixes=None, - add_suffix=False, - **kwargs) -> BaseFormatter: - """ Load the appropriate formatter for different types of data formats. :param dataset_path: Path to dataset file or dataset directory @@ -90,11 +49,6 @@ def _load_formatter(dataset_path, add_suffix=add_suffix, **kwargs) - # try huggingface dataset hub - elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: - return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs) - - # no data else: raise ValueError(f'Unable to load the dataset from [{dataset_path}]. ' f'It might be because Data-Juicer doesn\'t support ' diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index f8afc466f..52b315ecc 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -2,6 +2,7 @@ from data_juicer.core.data.load_strategy import ( DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey ) +from data_juicer.core.executor.base import ExecutorType class MockStrategy(DataLoadStrategy): def load_data(self): @@ -14,103 +15,118 @@ def setUp(self): def test_exact_match(self): # Register a specific strategy - @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', 'json') class TestStrategy(MockStrategy): pass # Test exact match - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'json') self.assertEqual(strategy, TestStrategy) # Test no match - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'csv') self.assertIsNone(strategy) def test_wildcard_matching(self): # Register strategies with different wildcard patterns - @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') class AllFilesStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('local', '*', '*') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, '*', '*') class AllLocalStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('*', '*', '*') + @DataLoadStrategyRegistry.register(ExecutorType.ANY, '*', '*') class FallbackStrategy(MockStrategy): pass # Test specific matches - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'json') self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'remote', 'json') self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard - strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.RAY, 'remote', 'json') self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard def test_specificity_priority(self): - @DataLoadStrategyRegistry.register('*', '*', '*') + @DataLoadStrategyRegistry.register(ExecutorType.ANY, '*', '*') class GeneralStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('local', '*', '*') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, '*', '*') class LocalStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') class LocalOndiskStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', 'json') class ExactStrategy(MockStrategy): pass # Test matching priority - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'json') self.assertEqual(strategy, ExactStrategy) # Should match exact first - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'csv') self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'remote', 'json') self.assertEqual(strategy, LocalStrategy) # Should match two wildcards - strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.RAY, 'remote', 'json') self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard def test_pattern_matching(self): - @DataLoadStrategyRegistry.register('local', 'ondisk', '*.json') + @DataLoadStrategyRegistry.register( + ExecutorType.LOCAL, 'ondisk', '*.json') class JsonStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register('local', 'ondisk', 'data_[0-9]*') + @DataLoadStrategyRegistry.register( + ExecutorType.LOCAL, 'ondisk', 'data_[0-9]*') class NumberedDataStrategy(MockStrategy): pass # Test pattern matching - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.json') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'test.json') self.assertEqual(strategy, JsonStrategy) - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'data_123') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'data_123') self.assertEqual(strategy, NumberedDataStrategy) - strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.csv') + strategy = DataLoadStrategyRegistry.get_strategy_class( + ExecutorType.LOCAL, 'ondisk', 'test.csv') self.assertIsNone(strategy) def test_strategy_key_matches(self): # Test StrategyKey matching directly - wildcard_key = StrategyKey('*', 'ondisk', '*.json') - specific_key = StrategyKey('local', 'ondisk', 'test.json') + wildcard_key = StrategyKey(ExecutorType.ANY, 'ondisk', '*.json') + specific_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'test.json') + # Exact keys don't match wildcards self.assertTrue(wildcard_key.matches(specific_key)) - self.assertFalse(specific_key.matches(wildcard_key)) # Exact keys don't match wildcards + self.assertFalse(specific_key.matches(wildcard_key)) # Test pattern matching - pattern_key = StrategyKey('local', '*', 'data_[0-9]*') - match_key = StrategyKey('local', 'ondisk', 'data_123') - no_match_key = StrategyKey('local', 'ondisk', 'data_abc') + pattern_key = StrategyKey(ExecutorType.LOCAL, '*', 'data_[0-9]*') + match_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'data_123') + no_match_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'data_abc') self.assertTrue(pattern_key.matches(match_key)) self.assertFalse(pattern_key.matches(no_match_key)) From 5dd17fec24fb2e35227ba647a50bc930465c1acd Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 9 Jan 2025 13:02:47 -0800 Subject: [PATCH 32/56] flip on DatasetBuilder; replace formatter --- data_juicer/core/__init__.py | 6 ++++-- data_juicer/core/data/dataset_builder.py | 2 +- data_juicer/core/executor/__init__.py | 9 ++++++--- data_juicer/core/executor/factory.py | 9 ++++----- data_juicer/core/executor/local_executor.py | 13 ++++--------- data_juicer/core/executor/ray_executor.py | 21 +++++++-------------- data_juicer/core/sandbox/factories.py | 2 +- demos/data_mixture/app.py | 5 ++--- tools/process_data.py | 3 +-- 9 files changed, 30 insertions(+), 40 deletions(-) diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index f5450fabc..89f4c6a28 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,7 +1,8 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import ExecutorBase, ExecutorFactory, LocalExecutor, RayExecutor +from .executor import (Executor, ExecutorBase, ExecutorFactory, ExecutorType, + RayExecutor) from .exporter import Exporter from .monitor import Monitor from .tracer import Tracer @@ -11,9 +12,10 @@ 'Analyzer', 'NestedDataset', 'ExecutorFactory', - 'LocalExecutor', + 'Executor', 'RayExecutor', 'ExecutorBase', + 'ExecutorType', 'Exporter', 'Monitor', 'Tracer', diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 92318e29c..e8e2d33d4 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -21,7 +21,7 @@ class DatasetBuilder(object): DatasetBuilder is a class that builds a dataset from a configuration. """ - def __init__(self, cfg: Namespace, executor_type: str): + def __init__(self, cfg: Namespace, executor_type: str = 'local'): # if generated_dataset_config present, prioritize if cfg.generated_dataset_config: self.use_generated_dataset_config = True diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index 31c448137..a837ad474 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -1,6 +1,9 @@ -from .base import ExecutorBase +from .base import ExecutorBase, ExecutorType from .factory import ExecutorFactory -from .local_executor import LocalExecutor +from .local_executor import Executor from .ray_executor import RayExecutor -__all__ = ['ExecutorBase', 'ExecutorFactory', 'LocalExecutor', 'RayExecutor'] +__all__ = [ + 'ExecutorBase', 'ExecutorFactory', 'Executor', 'RayExecutor', + 'ExecutorType' +] diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py index 3133e3b1a..a97e49291 100644 --- a/data_juicer/core/executor/factory.py +++ b/data_juicer/core/executor/factory.py @@ -1,16 +1,15 @@ from typing import Union -from .local_executor import LocalExecutor +from .local_executor import Executor from .ray_executor import RayExecutor class ExecutorFactory: @staticmethod - def create_executor( - executor_type: str) -> Union[LocalExecutor, RayExecutor]: - if executor_type == 'local': - return LocalExecutor() + def create_executor(executor_type: str) -> Union[Executor, RayExecutor]: + if executor_type in ('local', 'default'): + return Executor() elif executor_type == 'ray': return RayExecutor() # TODO: add nemo support diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index 5fba09169..f95ad5cb1 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -7,12 +7,12 @@ from loguru import logger from pydantic import PositiveInt +from data_juicer.core import ExecutorType from data_juicer.core.adapter import Adapter from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.core.executor import ExecutorBase from data_juicer.core.exporter import Exporter from data_juicer.core.tracer import Tracer -from data_juicer.format.load import load_formatter from data_juicer.ops import OPERATORS, load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector, @@ -22,7 +22,7 @@ from data_juicer.utils.sample import random_sample -class LocalExecutor(ExecutorBase): +class Executor(ExecutorBase): """ This Executor class is used to process a specific dataset. @@ -53,13 +53,8 @@ def __init__(self, cfg: Optional[Namespace] = None): # setup dataset builder logger.info('Setting up dataset builder...') - self.formatter = load_formatter( - dataset_path=self.cfg.dataset_path, - generated_dataset_config=self.cfg.generated_dataset_config, - text_keys=self.cfg.text_keys, - suffixes=self.cfg.suffixes, - add_suffix=self.cfg.add_suffix) - self.dataset_builder = DatasetBuilder(cfg) + self.dataset_builder = DatasetBuilder(cfg, + executor_type=ExecutorType.LOCAL) # whether to use checkpoint mechanism. If it's true, Executor will # check if there are existing checkpoints first and try to load the diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index 2d0bd453e..0dc1734bd 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -7,7 +7,8 @@ from pydantic import PositiveInt from data_juicer.core.adapter import Adapter -from data_juicer.core.data.ray_dataset import RayDataset +from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.core.executor import ExecutorType from data_juicer.ops import load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader @@ -60,6 +61,10 @@ def __init__(self, cfg=None): self.tmp_dir = os.path.join(self.work_dir, '.tmp', ray.get_runtime_context().get_job_id()) + # init dataset builder + self.datasetbuilder = DatasetBuilder(self.cfg, + executor_type=ExecutorType.RAY) + def run(self, load_data_np: Optional[PositiveInt] = None, skip_return=False): @@ -72,20 +77,8 @@ def run(self, """ # 1. load data logger.info('Loading dataset with Ray...') + dataset = self.datasetbuilder.load_dataset() - if self.cfg.get('generated_dataset_config', None): - generated_dataset_config = self.cfg.generated_dataset_config - assert isinstance(generated_dataset_config, - dict) and 'type' in generated_dataset_config - args = generated_dataset_config.copy() - obj_name = args.pop('type') - from data_juicer.format.formatter import FORMATTERS - dataset = FORMATTERS.modules[obj_name](**args).load_dataset() - else: - dataset = RayDataset.read_json(self.cfg.dataset_path) - - # convert all the path in dataset to absolute path - dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) # 2. extract processes logger.info('Preparing process operators...') ops = load_ops(self.cfg.process) diff --git a/data_juicer/core/sandbox/factories.py b/data_juicer/core/sandbox/factories.py index c3aeeedc0..0e1b5ca87 100644 --- a/data_juicer/core/sandbox/factories.py +++ b/data_juicer/core/sandbox/factories.py @@ -1,5 +1,5 @@ from data_juicer.core import Analyzer as DJAnalyzer -from data_juicer.core.executor import LocalExecutor as DJExecutor +from data_juicer.core.executor import Executor as DJExecutor from data_juicer.core.sandbox.evaluators import (Gpt3QualityEvaluator, InceptionEvaluator, VBenchEvaluator) diff --git a/demos/data_mixture/app.py b/demos/data_mixture/app.py index ec649cdb7..cc8efefb6 100644 --- a/demos/data_mixture/app.py +++ b/demos/data_mixture/app.py @@ -3,7 +3,7 @@ import pandas as pd import streamlit as st -from data_juicer.format import load_formatter +from data_juicer.core.data.dataset_builder import DatasetBuilder if st.__version__ >= '1.23.0': data_editor = st.data_editor @@ -96,8 +96,7 @@ def mix_dataset(): ' '.join([str(weight), ds_file]) for ds_file, weight in zip(ds_files, weights) ]) - formatter = load_formatter(data_path) - df = pd.DataFrame(formatter.load_dataset()) + df = pd.DataFrame(DatasetBuilder(data_path).load_dataset()) st.session_state.dataset = df else: diff --git a/tools/process_data.py b/tools/process_data.py index 71241cf41..8893ec100 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -1,7 +1,7 @@ from loguru import logger from data_juicer.config import init_configs -from data_juicer.core import Executor +from data_juicer.core import Executor, RayExecutor @logger.catch(reraise=True) @@ -10,7 +10,6 @@ def main(): if cfg.executor_type == 'default': executor = Executor(cfg) elif cfg.executor_type == 'ray': - from data_juicer.core.executor.ray_executor import RayExecutor executor = RayExecutor(cfg) executor.run() From eb3b1231800c57d2bd1fd8749567b3204f47058f Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 9 Jan 2025 14:01:19 -0800 Subject: [PATCH 33/56] minor fix --- configs/datasets/mixture.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index 9fe18cccb..7bca137b1 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -2,9 +2,9 @@ project_name: 'dataset-mixture' dataset: max_sample_num: 10000 configs: - - type: 'local' + - type: 'ondisk' weight: 1.0 path: 'path/to/json/file' - - type: 'local' + - type: 'ondisk' weight: 1.0 files: 'path/to/csv/file' From 7c171fbd8b4c7c2ab4111f4e7e04bca412ae86b8 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 9 Jan 2025 14:10:48 -0800 Subject: [PATCH 34/56] add ExecutorBase to RayExecutor --- data_juicer/core/executor/ray_executor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index 0dc1734bd..c299a5ed0 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -3,12 +3,13 @@ import time from typing import Optional +from jsonargparse import Namespace from loguru import logger from pydantic import PositiveInt from data_juicer.core.adapter import Adapter from data_juicer.core.data.dataset_builder import DatasetBuilder -from data_juicer.core.executor import ExecutorType +from data_juicer.core.executor import ExecutorBase, ExecutorType from data_juicer.ops import load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader @@ -32,7 +33,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): shutil.rmtree(self.tmp_dir) -class RayExecutor: +class RayExecutor(ExecutorBase): """ Executor based on Ray. @@ -44,7 +45,7 @@ class RayExecutor: """ - def __init__(self, cfg=None): + def __init__(self, cfg: Optional[Namespace] = None): """ Initialization method. From dd95df0f0c53f87af1dfdc8cd656d5f91f9d50a4 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 23 Jan 2025 13:47:07 -0800 Subject: [PATCH 35/56] fix bugs; use str for executor_type --- data_juicer/core/__init__.py | 5 +- data_juicer/core/analyzer.py | 13 ++--- data_juicer/core/data/dataset_builder.py | 25 ++++++---- data_juicer/core/data/load_strategy.py | 45 +++++++++-------- data_juicer/core/executor/__init__.py | 8 ++- data_juicer/core/executor/base.py | 7 --- data_juicer/core/executor/local_executor.py | 3 +- data_juicer/core/executor/ray_executor.py | 5 +- data_juicer/format/formatter.py | 2 + tests/core/data/test_config.yaml | 3 +- tests/core/data/test_config_list.yaml | 6 +-- tests/core/test_dataload_strategy.py | 55 ++++++++++----------- tests/core/test_dataset_builder.py | 34 ++++++------- 13 files changed, 98 insertions(+), 113 deletions(-) diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index 89f4c6a28..a3d05cb5c 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,8 +1,8 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import (Executor, ExecutorBase, ExecutorFactory, ExecutorType, - RayExecutor) +from .executor import Executor, ExecutorFactory, RayExecutor +from .executor.base import ExecutorBase from .exporter import Exporter from .monitor import Monitor from .tracer import Tracer @@ -15,7 +15,6 @@ 'Executor', 'RayExecutor', 'ExecutorBase', - 'ExecutorType', 'Exporter', 'Monitor', 'Tracer', diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index d9ac586e9..ef1fabb34 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -8,7 +8,7 @@ from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis from data_juicer.config import init_configs -from data_juicer.format import load_formatter +from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS, Filter, load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils @@ -44,14 +44,9 @@ def __init__(self, cfg: Optional[Namespace] = None): f'[{self.cfg.cache_compress}]') cache_utils.CACHE_COMPRESS = self.cfg.cache_compress - # setup formatter - logger.info('Setting up data formatter...') - self.formatter = load_formatter( - dataset_path=self.cfg.dataset_path, - generated_dataset_config=self.cfg.generated_dataset_config, - text_keys=self.cfg.text_keys, - suffixes=self.cfg.suffixes, - add_suffix=self.cfg.add_suffix) + # setup dataset builder + logger.info('Setting up dataset builder...') + self.dataset_builder = DatasetBuilder(cfg, executor_type='local') # prepare exporter and check export path suffix # NOTICE: no need to export dataset texts for analyzer diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index e8e2d33d4..ab73ace06 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -11,7 +11,6 @@ from data_juicer.core.data.data_validator import DataValidatorRegistry from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry from data_juicer.core.data.ray_dataset import RayDataset -from data_juicer.core.executor.base import ExecutorType from data_juicer.utils.file_utils import is_absolute_path from data_juicer.utils.sample import random_sample @@ -23,17 +22,20 @@ class DatasetBuilder(object): def __init__(self, cfg: Namespace, executor_type: str = 'local'): # if generated_dataset_config present, prioritize - if cfg.generated_dataset_config: + if hasattr( + cfg, + 'generated_dataset_config') and cfg.generated_dataset_config: self.use_generated_dataset_config = True self.generated_dataset_config = cfg.generated_dataset_config return + self.use_generated_dataset_config = False self.cfg = cfg - self.executor_type = ExecutorType(executor_type) + self.executor_type = executor_type - if cfg.dataset_path is not None: + if hasattr(cfg, 'dataset_path') and cfg.dataset_path is not None: ds_configs = rewrite_cli_datapath(cfg.dataset_path) - elif cfg.dataset is not None: + elif hasattr(cfg, 'dataset') and cfg.dataset is not None: ds_configs = cfg.dataset else: raise ConfigValidationError( @@ -80,8 +82,8 @@ def __init__(self, cfg: Namespace, executor_type: str = 'local'): data_source = ds_config.get('source', None) self.load_strategies.append( DataLoadStrategyRegistry.get_strategy_class( - self.executor_type.value, data_type, - data_source)(ds_config, cfg=self.cfg)) + self.executor_type, data_type, data_source)(ds_config, + cfg=self.cfg)) # initialzie the sample numbers self.max_sample_num = ds_configs.get('max_sample_num', None) @@ -90,6 +92,9 @@ def __init__(self, cfg: Namespace, executor_type: str = 'local'): self.weights = [stra.weight for stra in self.load_strategies] self.sample_numbers = get_sample_numbers(self.weights, self.max_sample_num) + else: + self.weights = [1.0 for stra in self.load_strategies] + self.sample_numbers = [None for stra in self.load_strategies] # initialize data validators self.validators = [] @@ -125,9 +130,9 @@ def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]: _datasets.append(dataset) # handle data mixture - if self.executor_type == ExecutorType.LOCAL: + if self.executor_type == 'local': return NestedDataset(concatenate_datasets(_datasets)) - elif self.executor_type == ExecutorType.RAY: + elif self.executor_type == 'ray': # TODO: support multiple datasets and mixing for ray assert len(_datasets) == 1, 'Ray setup supports one dataset now' return _datasets[0] @@ -169,7 +174,7 @@ def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> List: for p, w in zip(paths, weights): if os.path.isdir(p) or os.path.isfile(p): # local files - ret['configs'].append({'type': 'ondisk', 'path': [p], 'weight': w}) + ret['configs'].append({'type': 'ondisk', 'path': p, 'weight': w}) elif (not is_absolute_path(p) and not p.startswith('.') and p.count('/') <= 1): # remote huggingface diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 76c597d5c..8bf0c9c9e 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -8,7 +8,6 @@ from data_juicer.core.data import DJDataset, RayDataset from data_juicer.core.data.config_validator import ConfigValidator -from data_juicer.core.executor.base import ExecutorType from data_juicer.download.downloader import validate_snapshot_format from data_juicer.format.formatter import unify_format from data_juicer.format.load import load_formatter @@ -27,7 +26,7 @@ class StrategyKey: """ Immutable key for strategy registration with wildcard support """ - executor_type: ExecutorType + executor_type: str data_type: str data_source: str @@ -41,8 +40,7 @@ def matches(self, other: 'StrategyKey') -> bool: - '[seq]' matches any character in seq - '[!seq]' matches any character not in seq """ - return (fnmatch.fnmatch(other.executor_type.value, - self.executor_type.value) + return (fnmatch.fnmatch(other.executor_type, self.executor_type) and fnmatch.fnmatch(other.data_type, self.data_type) and fnmatch.fnmatch(other.data_source, self.data_source)) @@ -71,7 +69,7 @@ class DataLoadStrategyRegistry: @classmethod def get_strategy_class( - cls, executor_type: ExecutorType, data_type: str, + cls, executor_type: str, data_type: str, data_source: str) -> Optional[Type[DataLoadStrategy]]: """ Retrieve the most specific matching strategy @@ -81,7 +79,7 @@ def get_strategy_class( 2. Wildcard matches from most specific to most general """ # default to wildcard if not provided - executor_type = executor_type or ExecutorType.ANY + executor_type = executor_type or '*' data_type = data_type or '*' data_source = data_source or '*' @@ -121,8 +119,7 @@ def specificity_score(key: StrategyKey) -> int: return None @classmethod - def register(cls, executor_type: ExecutorType, data_type: str, - data_source: str): + def register(cls, executor_type: str, data_type: str, data_source: str): """ Decorator for registering data load strategies with wildcard support @@ -179,7 +176,7 @@ def load_data(self, **kwargs) -> DJDataset: # pass -@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'ondisk', 'json') +@DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { @@ -191,13 +188,13 @@ class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): } def load_data(self, **kwargs): - dataset = rd.read_json(self.ds_config.path) + dataset = rd.read_json(self.ds_config['path']) return RayDataset(dataset, - dataset_path=self.ds_config.path, + dataset_path=self.ds_config['path'], cfg=self.cfg) -@DataLoadStrategyRegistry.register(ExecutorType.RAY, 'remote', 'huggingface') +@DataLoadStrategyRegistry.register('ray', 'remote', 'huggingface') class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): CONFIG_VALIDATION_RULES = { @@ -213,7 +210,7 @@ def load_data(self, **kwargs): 'Huggingface data load strategy is not implemented') -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') +@DataLoadStrategyRegistry.register('local', 'ondisk', '*') class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for on disk data for LocalExecutor @@ -229,16 +226,18 @@ class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, **kwargs): + print(f'kwards: {kwargs}') # use proper formatter to load data - formatter = load_formatter(dataset_path=self.ds_config.path, + formatter = load_formatter(dataset_path=self.ds_config['path'], suffixes=self.cfg.suffixes, text_keys=self.cfg.text_keys, - add_suffix=self.cfg.add_suffix**kwargs) + add_suffix=self.cfg.add_suffix, + **kwargs) # TODO more sophiscated localformatter routing - return formatter.load_data() + return formatter.load_dataset() -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'huggingface') +@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface') class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for Huggingface dataset for LocalExecutor @@ -254,8 +253,8 @@ class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): } def load_data(self, **kwargs): - num_proc = kwargs.get('num_proc', 1) - ds = datasets.load_dataset(self.ds_config.path, + num_proc = kwargs.pop('num_proc', 1) + ds = datasets.load_dataset(self.ds_config['path'], split=self.ds_config.split, name=self.ds_config.name, limit=self.ds_config.limit, @@ -264,7 +263,7 @@ def load_data(self, **kwargs): ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc) -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'modelscope') +@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for ModelScope dataset for LocalExecutor @@ -275,7 +274,7 @@ def load_data(self): 'ModelScope data load strategy is not implemented') -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'arxiv') +@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv') class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for arxiv dataset for LocalExecutor @@ -294,7 +293,7 @@ def load_data(self): 'Arxiv data load strategy is not implemented') -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'wiki') +@DataLoadStrategyRegistry.register('local', 'remote', 'wiki') class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for wiki dataset for LocalExecutor @@ -312,7 +311,7 @@ def load_data(self): raise NotImplementedError('Wiki data load strategy is not implemented') -@DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'remote', 'commoncrawl') +@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl') class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): """ data load strategy for commoncrawl dataset for LocalExecutor diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index a837ad474..58402bf9b 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -1,9 +1,7 @@ -from .base import ExecutorBase, ExecutorType +from .base import ExecutorBase from .factory import ExecutorFactory from .local_executor import Executor from .ray_executor import RayExecutor -__all__ = [ - 'ExecutorBase', 'ExecutorFactory', 'Executor', 'RayExecutor', - 'ExecutorType' -] +__all__ = ['ExecutorBase' + 'ExecutorFactory', 'Executor', 'RayExecutor'] diff --git a/data_juicer/core/executor/base.py b/data_juicer/core/executor/base.py index 124f5a7b9..e1d6e2073 100644 --- a/data_juicer/core/executor/base.py +++ b/data_juicer/core/executor/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from enum import Enum from typing import Optional from jsonargparse import Namespace @@ -8,12 +7,6 @@ from data_juicer.config import init_configs -class ExecutorType(Enum): - LOCAL = 'local' - RAY = 'ray' - ANY = '*' - - class ExecutorBase(ABC): @abstractmethod diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index f95ad5cb1..8b24ecd22 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -7,7 +7,6 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.core import ExecutorType from data_juicer.core.adapter import Adapter from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.core.executor import ExecutorBase @@ -54,7 +53,7 @@ def __init__(self, cfg: Optional[Namespace] = None): # setup dataset builder logger.info('Setting up dataset builder...') self.dataset_builder = DatasetBuilder(cfg, - executor_type=ExecutorType.LOCAL) + executor_type=self.executor_type) # whether to use checkpoint mechanism. If it's true, Executor will # check if there are existing checkpoints first and try to load the diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index c299a5ed0..5304c2688 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -9,7 +9,7 @@ from data_juicer.core.adapter import Adapter from data_juicer.core.data.dataset_builder import DatasetBuilder -from data_juicer.core.executor import ExecutorBase, ExecutorType +from data_juicer.core.executor import ExecutorBase from data_juicer.ops import load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader @@ -63,8 +63,7 @@ def __init__(self, cfg: Optional[Namespace] = None): ray.get_runtime_context().get_job_id()) # init dataset builder - self.datasetbuilder = DatasetBuilder(self.cfg, - executor_type=ExecutorType.RAY) + self.datasetbuilder = DatasetBuilder(self.cfg, executor_type='ray') def run(self, load_data_np: Optional[PositiveInt] = None, diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 48690f48b..807d2de71 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -59,6 +59,8 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: :param global_cfg: global cfg used in consequent processes, :return: formatted dataset """ + _num_proc = self.kwargs.pop('num_proc', 1) + num_proc = num_proc or _num_proc datasets = load_dataset(self.type, data_files={ key.strip('.'): self.data_files[key] diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml index fe10ab1be..65db30be3 100644 --- a/tests/core/data/test_config.yaml +++ b/tests/core/data/test_config.yaml @@ -2,5 +2,4 @@ project_name: 'dataset-ondisk-json' dataset: configs: - type: 'ondisk' - path: - - 'sample.json' \ No newline at end of file + path: 'sample.json' \ No newline at end of file diff --git a/tests/core/data/test_config_list.yaml b/tests/core/data/test_config_list.yaml index ed8eeef47..15b52b036 100644 --- a/tests/core/data/test_config_list.yaml +++ b/tests/core/data/test_config_list.yaml @@ -2,8 +2,6 @@ project_name: 'dataset-ondisk-list' dataset: configs: - type: 'ondisk' - path: - - 'sample.json' + path: 'sample.json' - type: 'ondisk' - path: - - 'sample.txt' \ No newline at end of file + path: 'sample.txt' \ No newline at end of file diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index 52b315ecc..773b06eae 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -2,7 +2,6 @@ from data_juicer.core.data.load_strategy import ( DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey ) -from data_juicer.core.executor.base import ExecutorType class MockStrategy(DataLoadStrategy): def load_data(self): @@ -15,118 +14,118 @@ def setUp(self): def test_exact_match(self): # Register a specific strategy - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', 'json') + @DataLoadStrategyRegistry.register("local", 'ondisk', 'json') class TestStrategy(MockStrategy): pass # Test exact match strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'json') + "local", 'ondisk', 'json') self.assertEqual(strategy, TestStrategy) # Test no match strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'csv') + "local", 'ondisk', 'csv') self.assertIsNone(strategy) def test_wildcard_matching(self): # Register strategies with different wildcard patterns - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') + @DataLoadStrategyRegistry.register("local", 'ondisk', '*') class AllFilesStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, '*', '*') + @DataLoadStrategyRegistry.register("local", '*', '*') class AllLocalStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register(ExecutorType.ANY, '*', '*') + @DataLoadStrategyRegistry.register("*", '*', '*') class FallbackStrategy(MockStrategy): pass # Test specific matches strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'json') + "local", 'ondisk', 'json') self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'remote', 'json') + "local", 'remote', 'json') self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.RAY, 'remote', 'json') + "ray", 'remote', 'json') self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard def test_specificity_priority(self): - @DataLoadStrategyRegistry.register(ExecutorType.ANY, '*', '*') + @DataLoadStrategyRegistry.register("*", '*', '*') class GeneralStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, '*', '*') + @DataLoadStrategyRegistry.register("local", '*', '*') class LocalStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', '*') + @DataLoadStrategyRegistry.register("local", 'ondisk', '*') class LocalOndiskStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register(ExecutorType.LOCAL, 'ondisk', 'json') + @DataLoadStrategyRegistry.register("local", 'ondisk', 'json') class ExactStrategy(MockStrategy): pass # Test matching priority strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'json') + "local", 'ondisk', 'json') self.assertEqual(strategy, ExactStrategy) # Should match exact first strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'csv') + "local", 'ondisk', 'csv') self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'remote', 'json') + "local", 'remote', 'json') self.assertEqual(strategy, LocalStrategy) # Should match two wildcards strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.RAY, 'remote', 'json') + "ray", 'remote', 'json') self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard def test_pattern_matching(self): @DataLoadStrategyRegistry.register( - ExecutorType.LOCAL, 'ondisk', '*.json') + "local", 'ondisk', '*.json') class JsonStrategy(MockStrategy): pass @DataLoadStrategyRegistry.register( - ExecutorType.LOCAL, 'ondisk', 'data_[0-9]*') + "local", 'ondisk', 'data_[0-9]*') class NumberedDataStrategy(MockStrategy): pass # Test pattern matching strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'test.json') + "local", 'ondisk', 'test.json') self.assertEqual(strategy, JsonStrategy) strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'data_123') + "local", 'ondisk', 'data_123') self.assertEqual(strategy, NumberedDataStrategy) strategy = DataLoadStrategyRegistry.get_strategy_class( - ExecutorType.LOCAL, 'ondisk', 'test.csv') + "local", 'ondisk', 'test.csv') self.assertIsNone(strategy) def test_strategy_key_matches(self): # Test StrategyKey matching directly - wildcard_key = StrategyKey(ExecutorType.ANY, 'ondisk', '*.json') - specific_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'test.json') + wildcard_key = StrategyKey("*", 'ondisk', '*.json') + specific_key = StrategyKey("local", 'ondisk', 'test.json') # Exact keys don't match wildcards self.assertTrue(wildcard_key.matches(specific_key)) self.assertFalse(specific_key.matches(wildcard_key)) # Test pattern matching - pattern_key = StrategyKey(ExecutorType.LOCAL, '*', 'data_[0-9]*') - match_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'data_123') - no_match_key = StrategyKey(ExecutorType.LOCAL, 'ondisk', 'data_abc') + pattern_key = StrategyKey("local", '*', 'data_[0-9]*') + match_key = StrategyKey("local", 'ondisk', 'data_123') + no_match_key = StrategyKey("local", 'ondisk', 'data_abc') self.assertTrue(pattern_key.matches(match_key)) self.assertFalse(pattern_key.matches(no_match_key)) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index cde927c52..4e981603a 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -32,7 +32,7 @@ def test_rewrite_cli_datapath_local_single_file(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}]}, + {'path': dataset_path, 'type': 'ondisk', 'weight': 1.0}]}, ans) def test_rewrite_cli_datapath_local_directory(self): @@ -40,7 +40,7 @@ def test_rewrite_cli_datapath_local_directory(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}]}, + {'path': dataset_path, 'type': 'ondisk', 'weight': 1.0}]}, ans) def test_rewrite_cli_datapath_hf(self): @@ -63,8 +63,8 @@ def test_rewrite_cli_datapath_with_weights(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5}, - {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}]}, + {'path': './data/sample.json', 'type': 'ondisk', 'weight': 0.5}, + {'path': './data/sample.txt', 'type': 'ondisk', 'weight': 1.0}]}, ans) @patch('os.path.isdir') @@ -77,9 +77,9 @@ def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir): dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" expected = { 'configs': [ - {'type': 'ondisk', 'path': ['ds1.jsonl'], 'weight': 1.0}, - {'type': 'ondisk', 'path': ['ds2_dir'], 'weight': 2.0}, - {'type': 'ondisk', 'path': ['ds3.jsonl'], 'weight': 3.0} + {'type': 'ondisk', 'path': 'ds1.jsonl', 'weight': 1.0}, + {'type': 'ondisk', 'path': 'ds2_dir', 'weight': 2.0}, + {'type': 'ondisk', 'path': 'ds3.jsonl', 'weight': 3.0} ] } result = rewrite_cli_datapath(dataset_path) @@ -110,7 +110,7 @@ def test_rewrite_cli_datapath_with_max_samples(self): expected = { 'configs': [{ 'type': 'ondisk', - 'path': ['./data/sample.txt'], + 'path': './data/sample.txt', 'weight': 1.0 }], 'max_sample_num': 1000 @@ -126,7 +126,7 @@ def test_rewrite_cli_datapath_without_max_samples(self): expected = { 'configs': [{ 'type': 'ondisk', - 'path': ['./data/sample.txt'], + 'path': './data/sample.txt', 'weight': 1.0 }] } @@ -371,7 +371,7 @@ def test_builder_ondisk_config(self): self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'dataset-ondisk-json') self.assertEqual(cfg.dataset, - {'configs': [{'path': ['sample.json'], 'type': 'ondisk'}]}) + {'configs': [{'path': 'sample.json', 'type': 'ondisk'}]}) self.assertEqual(not cfg.dataset_path, True) def test_builder_ondisk_config_list(self): @@ -383,8 +383,8 @@ def test_builder_ondisk_config_list(self): self.assertEqual(cfg.project_name, 'dataset-ondisk-list') self.assertEqual(cfg.dataset, {'configs': [ - {'path': ['sample.json'], 'type': 'ondisk'}, - {'path': ['sample.txt'], 'type': 'ondisk'} + {'path': 'sample.json', 'type': 'ondisk'}, + {'path': 'sample.txt', 'type': 'ondisk'} ]}) self.assertEqual(not cfg.dataset_path, True) @@ -393,7 +393,7 @@ def test_builder_with_max_samples(self): self.base_cfg.dataset = { 'configs': [{ 'type': 'ondisk', - 'path': ['test.jsonl'], + 'path': 'test.jsonl', 'weight': 1.0 }], 'max_sample_num': 1000 @@ -409,7 +409,7 @@ def test_builder_without_max_samples(self): self.base_cfg.dataset = { 'configs': [{ 'type': 'ondisk', - 'path': ['test.jsonl'], + 'path': 'test.jsonl', 'weight': 1.0 }] } @@ -425,12 +425,12 @@ def test_mixed_dataset_configs(self): 'configs': [ { 'type': 'ondisk', - 'path': ['test1.jsonl'], + 'path': 'test1.jsonl', 'weight': 1.0 }, { 'type': 'ondisk', - 'path': ['test2.jsonl'], + 'path': 'test2.jsonl', 'weight': 2.0 } ], @@ -458,7 +458,7 @@ def test_invalid_max_sample_num(self): self.base_cfg.dataset = { 'configs': [{ 'type': 'ondisk', - 'path': ['test.jsonl'], + 'path': 'test.jsonl', 'weight': 1.0 }], 'max_sample_num': value From 530efa8df597c624ecb26bff24ca1990813129c6 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 23 Jan 2025 14:28:13 -0800 Subject: [PATCH 36/56] add add_same_content_to_new_column reference --- data_juicer/core/data/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/data_juicer/core/data/__init__.py b/data_juicer/core/data/__init__.py index d93899665..afd5b3acb 100644 --- a/data_juicer/core/data/__init__.py +++ b/data_juicer/core/data/__init__.py @@ -1,6 +1,9 @@ -from .dj_dataset import DJDataset, NestedDataset, wrap_func_with_nested_access +from .dj_dataset import (DJDataset, NestedDataset, + add_same_content_to_new_column, + wrap_func_with_nested_access) from .ray_dataset import RayDataset __all__ = [ - 'DJDataset', 'NestedDataset', 'RayDataset', 'wrap_func_with_nested_access' + 'DJDataset', 'NestedDataset', 'RayDataset', 'wrap_func_with_nested_access', + 'add_same_content_to_new_column' ] From 3b726bd85fc9a49974643583e1edfca43a950081 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 23 Jan 2025 21:51:33 -0800 Subject: [PATCH 37/56] ray data defaults to json --- data_juicer/core/data/load_strategy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 8bf0c9c9e..1449fe9ce 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -176,9 +176,11 @@ def load_data(self, **kwargs) -> DJDataset: # pass -@DataLoadStrategyRegistry.register('ray', 'ondisk', 'json') +@DataLoadStrategyRegistry.register('ray', 'ondisk', '*') class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): + # TODO ray defaults to json + CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], 'field_types': { From cac8e5ea2e29a848bf90136bd55d2e546770d9ed Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 24 Jan 2025 10:52:43 -0800 Subject: [PATCH 38/56] fix dataset_path bug; add ray config test --- configs/datasets/mixture.yaml | 2 +- data_juicer/config/config.py | 4 +- data_juicer/core/data/dataset_builder.py | 7 +- data_juicer/core/data/load_strategy.py | 5 +- .../configs/demo-new-config.yaml | 72 +++++++++++++++++++ tests/core/data/test_config_ray.yaml | 14 ++++ tests/core/test_dataset_builder.py | 35 ++++++++- 7 files changed, 131 insertions(+), 8 deletions(-) create mode 100644 demos/process_on_ray/configs/demo-new-config.yaml create mode 100644 tests/core/data/test_config_ray.yaml diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index 7bca137b1..da2f077e8 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -7,4 +7,4 @@ dataset: path: 'path/to/json/file' - type: 'ondisk' weight: 1.0 - files: 'path/to/csv/file' + path: 'path/to/csv/file' diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index f821a2e82..96c8498b5 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -459,14 +459,14 @@ def init_setup_from_cfg(cfg: Namespace): # check and get dataset dir if cfg.get('dataset_path', None) and os.path.exists(cfg.dataset_path): - logger.warning('dataset_path config is set and a valid local path') + logger.info('dataset_path config is set and a valid local path') cfg.dataset_path = os.path.abspath(cfg.dataset_path) if os.path.isdir(cfg.dataset_path): cfg.dataset_dir = cfg.dataset_path else: cfg.dataset_dir = os.path.dirname(cfg.dataset_path) elif cfg.dataset_path == '' and cfg.get('dataset', None): - logger.warning('dataset_path config is empty; dataset is present') + logger.info('dataset_path config is empty; dataset is present') cfg.dataset_dir = '' else: logger.warning(f'dataset_path [{cfg.dataset_path}] is not a valid ' diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index ab73ace06..3c25f1257 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -5,6 +5,7 @@ import numpy as np from datasets import concatenate_datasets +from loguru import logger from data_juicer.core.data import NestedDataset from data_juicer.core.data.config_validator import ConfigValidationError @@ -33,9 +34,11 @@ def __init__(self, cfg: Namespace, executor_type: str = 'local'): self.cfg = cfg self.executor_type = executor_type - if hasattr(cfg, 'dataset_path') and cfg.dataset_path is not None: + if hasattr(cfg, 'dataset_path') and cfg.dataset_path: + logger.info(f'found dataset_path setting: {cfg.dataset_path}') ds_configs = rewrite_cli_datapath(cfg.dataset_path) - elif hasattr(cfg, 'dataset') and cfg.dataset is not None: + elif hasattr(cfg, 'dataset') and cfg.dataset: + logger.info(f'found dataset setting: {cfg.dataset}') ds_configs = cfg.dataset else: raise ConfigValidationError( diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 1449fe9ce..aebf7d6b9 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -262,7 +262,10 @@ def load_data(self, **kwargs): limit=self.ds_config.limit, num_proc=num_proc, **kwargs) - ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc) + ds = unify_format(ds, + text_keys=self.text_keys, + num_proc=num_proc, + global_cfg=self.cfg) @DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') diff --git a/demos/process_on_ray/configs/demo-new-config.yaml b/demos/process_on_ray/configs/demo-new-config.yaml new file mode 100644 index 000000000..901d33c0e --- /dev/null +++ b/demos/process_on_ray/configs/demo-new-config.yaml @@ -0,0 +1,72 @@ +# Process config example for dataset + +# global parameters +project_name: 'ray-demo-new-config' +dataset: + configs: + - type: ondisk + path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file + weight: 1.0 + +export_path: './outputs/demo/demo-processed' + +executor_type: 'ray' +ray_address: 'auto' # change to your ray cluster address, e.g., ray://: + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: en # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: en # consider stopwords in what language + tokenization: false # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range diff --git a/tests/core/data/test_config_ray.yaml b/tests/core/data/test_config_ray.yaml new file mode 100644 index 000000000..e394f0b26 --- /dev/null +++ b/tests/core/data/test_config_ray.yaml @@ -0,0 +1,14 @@ + +# global parameters +project_name: 'ray-demo-new-config' +dataset: + configs: + - type: ondisk + path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file + weight: 1.0 + +export_path: './outputs/demo/demo-processed' + +executor_type: 'ray' +ray_address: 'auto' + diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 4e981603a..7ac101d16 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -11,8 +11,12 @@ from data_juicer.core.data.config_validator import ConfigValidationError from data_juicer.utils.unittest_utils import (DataJuicerTestCaseBase, SKIPPED_TESTS) +from data_juicer.core.data.load_strategy import RayOndiskJsonDataLoadStrategy +WORK_DIR = os.path.dirname(os.path.realpath(__file__)) + + @SKIPPED_TESTS.register_module() class DatasetBuilderTest(DataJuicerTestCaseBase): @@ -364,7 +368,7 @@ def test_builder_invalid_dataset_config_type(self): str(context.exception)) def test_builder_ondisk_config(self): - test_config_file = './data/test_config.yaml' + test_config_file = os.path.join(WORK_DIR, 'data/test_config.yaml') out = StringIO() with redirect_stdout(out): cfg = init_configs(args=f'--config {test_config_file}'.split()) @@ -375,7 +379,7 @@ def test_builder_ondisk_config(self): self.assertEqual(not cfg.dataset_path, True) def test_builder_ondisk_config_list(self): - test_config_file = './data/test_config_list.yaml' + test_config_file = os.path.join(WORK_DIR, 'data/test_config_list.yaml') out = StringIO() with redirect_stdout(out): cfg = init_configs(args=f'--config {test_config_file}'.split()) @@ -469,5 +473,32 @@ def test_invalid_max_sample_num(self): self.assertIn('should be a positive integer', str(context.exception)) + def test_builder_ray_config(self): + """Test loading Ray configuration from YAML""" + test_config_file = os.path.join(WORK_DIR, 'data/test_config_ray.yaml') + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_config_file}'.split()) + + # Verify basic config + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'ray-demo-new-config') + self.assertEqual(cfg.executor_type, 'ray') + self.assertEqual(cfg.ray_address, 'auto') + + # Verify dataset config + self.assertEqual(cfg.dataset, { + 'configs': [{ + 'type': 'ondisk', + 'path': './demos/process_on_ray/data/demo-dataset.jsonl', + 'weight': 1.0 + }] + }) + + # Create builder and verify + builder = DatasetBuilder(cfg, executor_type=cfg.executor_type) + self.assertEqual(len(builder.load_strategies), 1) + self.assertIsInstance(builder.load_strategies[0], RayOndiskJsonDataLoadStrategy) + if __name__ == '__main__': unittest.main() From a99c9b5ee6f0c91fafaf8eaaf7704befaf09dde7 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 24 Jan 2025 11:34:51 -0800 Subject: [PATCH 39/56] tests video on ray config --- .../configs/demo-new-config.yaml | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 demos/process_video_on_ray/configs/demo-new-config.yaml diff --git a/demos/process_video_on_ray/configs/demo-new-config.yaml b/demos/process_video_on_ray/configs/demo-new-config.yaml new file mode 100644 index 000000000..6300b2e8c --- /dev/null +++ b/demos/process_video_on_ray/configs/demo-new-config.yaml @@ -0,0 +1,36 @@ +# Process config example for dataset + +# global parameters +project_name: 'ray-demo' +executor_type: 'ray' +dataset: + configs: + - type: ondisk + path: './demos/process_video_on_ray/data/demo-dataset.jsonl' # path to your dataset directory or file +ray_address: 'auto' # change to your ray cluster address, e.g., ray://: +export_path: './outputs/demo/demo-processed-ray-videos' + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - video_duration_filter: + min_duration: 20 + max_duration: 100 + - video_resolution_filter: # filter samples according to the resolution of videos in them + min_width: 200 # the min resolution of horizontal resolution filter range (unit p) + max_width: 4096 # the max resolution of horizontal resolution filter range (unit p) + min_height: 200 # the min resolution of vertical resolution filter range (unit p) + max_height: 4096 # the max resolution of vertical resolution filter range (unit p) + any_or_all: any + # Mapper ops + - video_split_by_duration_mapper: # Mapper to split video by duration. + split_duration: 10 # duration of each video split in seconds. + min_last_split_duration: 0 # the minimum allowable duration in seconds for the last video split. If the duration of the last split is less than this value, it will be discarded. + keep_original_sample: true + - video_resize_aspect_ratio_mapper: + min_ratio: 1 + max_ratio: 1.1 + strategy: increase + - video_split_by_key_frame_mapper: # Mapper to split video by key frame. + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. It's True in default \ No newline at end of file From 3c9caf539c69e2d3cb9cb5aea7169334f4f8a96e Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 24 Jan 2025 12:30:35 -0800 Subject: [PATCH 40/56] add default cfg logic; fix data_mixture demo --- data_juicer/config/__init__.py | 6 ++-- data_juicer/config/config.py | 33 +++++++++++++++++++++ data_juicer/core/executor/local_executor.py | 2 +- demos/data_mixture/app.py | 7 +++-- tests/config/test_config_funcs.py | 28 ++++++++++++++++- 5 files changed, 69 insertions(+), 7 deletions(-) diff --git a/data_juicer/config/__init__.py b/data_juicer/config/__init__.py index 4060b6ac7..ba6deb866 100644 --- a/data_juicer/config/__init__.py +++ b/data_juicer/config/__init__.py @@ -1,7 +1,7 @@ -from .config import (export_config, get_init_configs, init_configs, - merge_config, prepare_side_configs) +from .config import (export_config, get_default_cfg, get_init_configs, + init_configs, merge_config, prepare_side_configs) __all__ = [ 'init_configs', 'get_init_configs', 'export_config', 'merge_config', - 'prepare_side_configs' + 'prepare_side_configs', 'get_default_cfg' ] diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 96c8498b5..86f0c411f 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -920,3 +920,36 @@ def get_init_configs(cfg: Union[Namespace, Dict]): json.dump(cfg, f) inited_dj_cfg = init_configs(['--config', temp_file]) return inited_dj_cfg + + +def get_default_cfg(): + """Get default config values from config_all.yaml""" + cfg = Namespace() + + # Get path to config_all.yaml + config_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join(config_dir, + '../../configs/config_all.yaml') + + # Load default values from yaml + with open(default_config_path, 'r', encoding='utf-8') as f: + defaults = yaml.safe_load(f) + + # Convert to flat dictionary for namespace + flat_defaults = { + 'executor_type': 'default', + 'ray_address': 'auto', + 'suffixes': None, + 'text_keys': 'text', + 'add_suffix': False, + 'export_path': './outputs', + # Add other top-level keys from config_all.yaml + **defaults + } + + # Update cfg with defaults + for key, value in flat_defaults.items(): + if not hasattr(cfg, key): + setattr(cfg, key, value) + + return cfg diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py index 8b24ecd22..6d4fdc636 100644 --- a/data_juicer/core/executor/local_executor.py +++ b/data_juicer/core/executor/local_executor.py @@ -36,7 +36,7 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace. """ super().__init__(cfg) - self.executor_type = 'local' + self.executor_type = 'default' self.work_dir = self.cfg.work_dir self.tracer = None diff --git a/demos/data_mixture/app.py b/demos/data_mixture/app.py index cc8efefb6..c05358806 100644 --- a/demos/data_mixture/app.py +++ b/demos/data_mixture/app.py @@ -2,8 +2,8 @@ import pandas as pd import streamlit as st - from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.config import get_default_cfg if st.__version__ >= '1.23.0': data_editor = st.data_editor @@ -96,7 +96,10 @@ def mix_dataset(): ' '.join([str(weight), ds_file]) for ds_file, weight in zip(ds_files, weights) ]) - df = pd.DataFrame(DatasetBuilder(data_path).load_dataset()) + cfg = get_default_cfg() + cfg.dataset_path = data_path + dataset_builder = DatasetBuilder(cfg) + df = pd.DataFrame(dataset_builder.load_dataset()) st.session_state.dataset = df else: diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 9ae5bd55e..87f3ff85c 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -5,7 +5,7 @@ from jsonargparse import Namespace -from data_juicer.config import init_configs +from data_juicer.config import init_configs, get_default_cfg from data_juicer.ops import load_ops from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -276,5 +276,31 @@ def test_op_params_parsing(self): self.assertIn(base_param_key, params) + def test_get_default_cfg(self): + """Test getting default configuration from config_all.yaml""" + # Get default config + cfg = get_default_cfg() + + # Verify basic default values + self.assertIsInstance(cfg, Namespace) + + # Test essential defaults + self.assertEqual(cfg.executor_type, 'default') + self.assertEqual(cfg.ray_address, 'auto') + self.assertEqual(cfg.text_keys, 'text') + self.assertEqual(cfg.add_suffix, False) + self.assertEqual(cfg.export_path, '/path/to/result/dataset.jsonl') + self.assertEqual(cfg.suffixes, []) + + # Test other important defaults from config_all.yaml + self.assertTrue(hasattr(cfg, 'np')) # Number of processes + self.assertTrue(hasattr(cfg, 'use_cache')) # Cache usage flag + self.assertTrue(hasattr(cfg, 'temp_dir')) # Temporary directory + + # Test default values are of correct type + self.assertIsInstance(cfg.executor_type, str) + self.assertIsInstance(cfg.add_suffix, bool) + self.assertIsInstance(cfg.export_path, str) + if __name__ == '__main__': unittest.main() From b9f6a998d5666674d88a710d2a9478778a5b8c02 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 27 Jan 2025 12:26:47 -0800 Subject: [PATCH 41/56] default executor + local data; fix analyzer bug --- .../{ondisk_json.yaml => local_json.yaml} | 2 +- ...ondisk_parquet.yaml => local_parquet.yaml} | 2 +- configs/datasets/mixture.yaml | 4 +- configs/datasets/validation.yaml | 2 +- data_juicer/core/analyzer.py | 4 +- data_juicer/core/data/dataset_builder.py | 9 +-- data_juicer/core/data/load_strategy.py | 42 +++++++------- data_juicer/core/executor/__init__.py | 2 +- ...{local_executor.py => default_executor.py} | 0 data_juicer/core/executor/factory.py | 2 +- tests/core/data/test_config.yaml | 4 +- tests/core/data/test_config_list.yaml | 6 +- tests/core/data/test_config_ray.yaml | 2 +- tests/core/test_dataload_strategy.py | 46 +++++++-------- tests/core/test_dataset_builder.py | 56 +++++++++---------- 15 files changed, 92 insertions(+), 91 deletions(-) rename configs/datasets/{ondisk_json.yaml => local_json.yaml} (83%) rename configs/datasets/{ondisk_parquet.yaml => local_parquet.yaml} (84%) rename data_juicer/core/executor/{local_executor.py => default_executor.py} (100%) diff --git a/configs/datasets/ondisk_json.yaml b/configs/datasets/local_json.yaml similarity index 83% rename from configs/datasets/ondisk_json.yaml rename to configs/datasets/local_json.yaml index a01e3b5a1..791ec9f32 100644 --- a/configs/datasets/ondisk_json.yaml +++ b/configs/datasets/local_json.yaml @@ -2,5 +2,5 @@ project_name: 'dataset-ondisk-json' dataset: configs: - - type: 'ondisk' + - type: 'local' path: 'path/to/json/file' diff --git a/configs/datasets/ondisk_parquet.yaml b/configs/datasets/local_parquet.yaml similarity index 84% rename from configs/datasets/ondisk_parquet.yaml rename to configs/datasets/local_parquet.yaml index e0f2fb144..bfded66f8 100644 --- a/configs/datasets/ondisk_parquet.yaml +++ b/configs/datasets/local_parquet.yaml @@ -2,5 +2,5 @@ project_name: 'dataset-ondisk-parquet' dataset: configs: - - type: 'ondisk' + - type: 'local' path: 'path/to/parquet/file' diff --git a/configs/datasets/mixture.yaml b/configs/datasets/mixture.yaml index da2f077e8..14999ca1a 100644 --- a/configs/datasets/mixture.yaml +++ b/configs/datasets/mixture.yaml @@ -2,9 +2,9 @@ project_name: 'dataset-mixture' dataset: max_sample_num: 10000 configs: - - type: 'ondisk' + - type: 'local' weight: 1.0 path: 'path/to/json/file' - - type: 'ondisk' + - type: 'local' weight: 1.0 path: 'path/to/csv/file' diff --git a/configs/datasets/validation.yaml b/configs/datasets/validation.yaml index 77947e48d..10aa138b2 100644 --- a/configs/datasets/validation.yaml +++ b/configs/datasets/validation.yaml @@ -1,6 +1,6 @@ dataset: configs: - - type: ondisk + - type: local path: path/to/data.json validators: diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index ef1fabb34..8b4d0cdcb 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -46,7 +46,7 @@ def __init__(self, cfg: Optional[Namespace] = None): # setup dataset builder logger.info('Setting up dataset builder...') - self.dataset_builder = DatasetBuilder(cfg, executor_type='local') + self.dataset_builder = DatasetBuilder(cfg, executor_type='default') # prepare exporter and check export path suffix # NOTICE: no need to export dataset texts for analyzer @@ -86,7 +86,7 @@ def run(self, load_data_np = self.cfg.np if dataset is None: logger.info('Loading dataset from data formatter...') - dataset = self.formatter.load_dataset(load_data_np, self.cfg) + dataset = self.dataset_builder.load_dataset(num_proc=load_data_np) else: logger.info(f'Using existing dataset {dataset}') if self.cfg.auto: diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 3c25f1257..0203554f2 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -21,7 +21,7 @@ class DatasetBuilder(object): DatasetBuilder is a class that builds a dataset from a configuration. """ - def __init__(self, cfg: Namespace, executor_type: str = 'local'): + def __init__(self, cfg: Namespace, executor_type: str = 'default'): # if generated_dataset_config present, prioritize if hasattr( cfg, @@ -133,11 +133,12 @@ def load_dataset(self, **kwargs) -> Union[NestedDataset, RayDataset]: _datasets.append(dataset) # handle data mixture - if self.executor_type == 'local': + if self.executor_type == 'default': return NestedDataset(concatenate_datasets(_datasets)) elif self.executor_type == 'ray': # TODO: support multiple datasets and mixing for ray - assert len(_datasets) == 1, 'Ray setup supports one dataset now' + assert len( + _datasets) == 1, 'Ray setup only supports one dataset now' return _datasets[0] @classmethod @@ -177,7 +178,7 @@ def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> List: for p, w in zip(paths, weights): if os.path.isdir(p) or os.path.isfile(p): # local files - ret['configs'].append({'type': 'ondisk', 'path': p, 'weight': w}) + ret['configs'].append({'type': 'local', 'path': p, 'weight': w}) elif (not is_absolute_path(p) and not p.startswith('.') and p.count('/') <= 1): # remote huggingface diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index aebf7d6b9..9473356ac 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -123,8 +123,8 @@ def register(cls, executor_type: str, data_type: str, data_source: str): """ Decorator for registering data load strategies with wildcard support - :param executor_type: Type of executor (e.g., 'local', 'ray') - :param data_type: Type of data (e.g., 'ondisk', 'remote') + :param executor_type: Type of executor (e.g., 'default', 'ray') + :param data_type: Type of data (e.g., 'local', 'remote') :param data_source: Specific data source (e.g., 'arxiv', 's3') :return: Decorator function """ @@ -153,7 +153,7 @@ def load_data(self, **kwargs) -> RayDataset: pass -class LocalDataLoadStrategy(DataLoadStrategy): +class DefaultDataLoadStrategy(DataLoadStrategy): """ abstract class for data load strategy for LocalExecutor """ @@ -176,8 +176,8 @@ def load_data(self, **kwargs) -> DJDataset: # pass -@DataLoadStrategyRegistry.register('ray', 'ondisk', '*') -class RayOndiskJsonDataLoadStrategy(RayDataLoadStrategy): +@DataLoadStrategyRegistry.register('ray', 'local', '*') +class RayLocalJsonDataLoadStrategy(RayDataLoadStrategy): # TODO ray defaults to json @@ -212,8 +212,8 @@ def load_data(self, **kwargs): 'Huggingface data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'ondisk', '*') -class LocalOndiskDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'local', '*') +class DefaultLocalDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for on disk data for LocalExecutor rely on AutoFormatter for actual data loading @@ -239,8 +239,8 @@ def load_data(self, **kwargs): return formatter.load_dataset() -@DataLoadStrategyRegistry.register('local', 'remote', 'huggingface') -class LocalHuggingfaceDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'remote', 'huggingface') +class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for Huggingface dataset for LocalExecutor """ @@ -268,19 +268,19 @@ def load_data(self, **kwargs): global_cfg=self.cfg) -@DataLoadStrategyRegistry.register('local', 'remote', 'modelscope') -class LocalModelScopeDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'remote', 'modelscope') +class DefaultModelScopeDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for ModelScope dataset for LocalExecutor """ - def load_data(self): + def load_data(self, **kwargs): raise NotImplementedError( 'ModelScope data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'arxiv') -class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'remote', 'arxiv') +class DefaultArxivDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for arxiv dataset for LocalExecutor """ @@ -293,13 +293,13 @@ class LocalArxivDataLoadStrategy(LocalDataLoadStrategy): 'custom_validators': {} } - def load_data(self): + def load_data(self, **kwargs): raise NotImplementedError( 'Arxiv data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'wiki') -class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'remote', 'wiki') +class DefaultWikiDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for wiki dataset for LocalExecutor """ @@ -312,12 +312,12 @@ class LocalWikiDataLoadStrategy(LocalDataLoadStrategy): 'custom_validators': {} } - def load_data(self): + def load_data(self, **kwargs): raise NotImplementedError('Wiki data load strategy is not implemented') -@DataLoadStrategyRegistry.register('local', 'remote', 'commoncrawl') -class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): +@DataLoadStrategyRegistry.register('default', 'remote', 'commoncrawl') +class DefaultCommonCrawlDataLoadStrategy(DefaultDataLoadStrategy): """ data load strategy for commoncrawl dataset for LocalExecutor """ @@ -336,6 +336,6 @@ class LocalCommonCrawlDataLoadStrategy(LocalDataLoadStrategy): } } - def load_data(self): + def load_data(self, **kwargs): raise NotImplementedError( 'CommonCrawl data load strategy is not implemented') diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index 58402bf9b..75ed7676a 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -1,6 +1,6 @@ from .base import ExecutorBase +from .default_executor import Executor from .factory import ExecutorFactory -from .local_executor import Executor from .ray_executor import RayExecutor __all__ = ['ExecutorBase' diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/default_executor.py similarity index 100% rename from data_juicer/core/executor/local_executor.py rename to data_juicer/core/executor/default_executor.py diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py index a97e49291..4ca162350 100644 --- a/data_juicer/core/executor/factory.py +++ b/data_juicer/core/executor/factory.py @@ -1,6 +1,6 @@ from typing import Union -from .local_executor import Executor +from .default_executor import Executor from .ray_executor import RayExecutor diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml index 65db30be3..3a1679585 100644 --- a/tests/core/data/test_config.yaml +++ b/tests/core/data/test_config.yaml @@ -1,5 +1,5 @@ -project_name: 'dataset-ondisk-json' +project_name: 'dataset-local-json' dataset: configs: - - type: 'ondisk' + - type: 'local' path: 'sample.json' \ No newline at end of file diff --git a/tests/core/data/test_config_list.yaml b/tests/core/data/test_config_list.yaml index 15b52b036..d32964672 100644 --- a/tests/core/data/test_config_list.yaml +++ b/tests/core/data/test_config_list.yaml @@ -1,7 +1,7 @@ -project_name: 'dataset-ondisk-list' +project_name: 'dataset-local-list' dataset: configs: - - type: 'ondisk' + - type: 'local' path: 'sample.json' - - type: 'ondisk' + - type: 'local' path: 'sample.txt' \ No newline at end of file diff --git a/tests/core/data/test_config_ray.yaml b/tests/core/data/test_config_ray.yaml index e394f0b26..ff3220c15 100644 --- a/tests/core/data/test_config_ray.yaml +++ b/tests/core/data/test_config_ray.yaml @@ -3,7 +3,7 @@ project_name: 'ray-demo-new-config' dataset: configs: - - type: ondisk + - type: local path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file weight: 1.0 diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index 773b06eae..a9a8f7087 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -14,27 +14,27 @@ def setUp(self): def test_exact_match(self): # Register a specific strategy - @DataLoadStrategyRegistry.register("local", 'ondisk', 'json') + @DataLoadStrategyRegistry.register("default", 'local', 'json') class TestStrategy(MockStrategy): pass # Test exact match strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'json') + "default", 'local', 'json') self.assertEqual(strategy, TestStrategy) # Test no match strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'csv') + "default", 'local', 'csv') self.assertIsNone(strategy) def test_wildcard_matching(self): # Register strategies with different wildcard patterns - @DataLoadStrategyRegistry.register("local", 'ondisk', '*') + @DataLoadStrategyRegistry.register("default", 'local', '*') class AllFilesStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register("local", '*', '*') + @DataLoadStrategyRegistry.register("default", '*', '*') class AllLocalStrategy(MockStrategy): pass @@ -44,11 +44,11 @@ class FallbackStrategy(MockStrategy): # Test specific matches strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'json') + "default", 'local', 'json') self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'remote', 'json') + "default", 'remote', 'json') self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( @@ -60,29 +60,29 @@ def test_specificity_priority(self): class GeneralStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register("local", '*', '*') + @DataLoadStrategyRegistry.register("default", '*', '*') class LocalStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register("local", 'ondisk', '*') + @DataLoadStrategyRegistry.register("default", 'local', '*') class LocalOndiskStrategy(MockStrategy): pass - @DataLoadStrategyRegistry.register("local", 'ondisk', 'json') + @DataLoadStrategyRegistry.register("default", 'local', 'json') class ExactStrategy(MockStrategy): pass # Test matching priority strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'json') + "default", 'local', 'json') self.assertEqual(strategy, ExactStrategy) # Should match exact first strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'csv') + "default", 'local', 'csv') self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'remote', 'json') + "default", 'remote', 'json') self.assertEqual(strategy, LocalStrategy) # Should match two wildcards strategy = DataLoadStrategyRegistry.get_strategy_class( @@ -91,41 +91,41 @@ class ExactStrategy(MockStrategy): def test_pattern_matching(self): @DataLoadStrategyRegistry.register( - "local", 'ondisk', '*.json') + "default", 'local', '*.json') class JsonStrategy(MockStrategy): pass @DataLoadStrategyRegistry.register( - "local", 'ondisk', 'data_[0-9]*') + "default", 'local', 'data_[0-9]*') class NumberedDataStrategy(MockStrategy): pass # Test pattern matching strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'test.json') + "default", 'local', 'test.json') self.assertEqual(strategy, JsonStrategy) strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'data_123') + "default", 'local', 'data_123') self.assertEqual(strategy, NumberedDataStrategy) strategy = DataLoadStrategyRegistry.get_strategy_class( - "local", 'ondisk', 'test.csv') + "default", 'local', 'test.csv') self.assertIsNone(strategy) def test_strategy_key_matches(self): # Test StrategyKey matching directly - wildcard_key = StrategyKey("*", 'ondisk', '*.json') - specific_key = StrategyKey("local", 'ondisk', 'test.json') + wildcard_key = StrategyKey("*", 'local', '*.json') + specific_key = StrategyKey("default", 'local', 'test.json') # Exact keys don't match wildcards self.assertTrue(wildcard_key.matches(specific_key)) self.assertFalse(specific_key.matches(wildcard_key)) # Test pattern matching - pattern_key = StrategyKey("local", '*', 'data_[0-9]*') - match_key = StrategyKey("local", 'ondisk', 'data_123') - no_match_key = StrategyKey("local", 'ondisk', 'data_abc') + pattern_key = StrategyKey("default", '*', 'data_[0-9]*') + match_key = StrategyKey("default", 'local', 'data_123') + no_match_key = StrategyKey("default", 'local', 'data_abc') self.assertTrue(pattern_key.matches(match_key)) self.assertFalse(pattern_key.matches(no_match_key)) diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 7ac101d16..c0964fe55 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -11,7 +11,7 @@ from data_juicer.core.data.config_validator import ConfigValidationError from data_juicer.utils.unittest_utils import (DataJuicerTestCaseBase, SKIPPED_TESTS) -from data_juicer.core.data.load_strategy import RayOndiskJsonDataLoadStrategy +from data_juicer.core.data.load_strategy import RayLocalJsonDataLoadStrategy WORK_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -24,7 +24,7 @@ def setUp(self): """Setup basic configuration for tests""" self.base_cfg = Namespace() self.base_cfg.dataset_path = None - self.executor_type = 'local' + self.executor_type = 'default' # Get the directory where this test file is located test_file_dir = os.path.dirname(os.path.abspath(__file__)) @@ -36,7 +36,7 @@ def test_rewrite_cli_datapath_local_single_file(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': dataset_path, 'type': 'ondisk', 'weight': 1.0}]}, + {'path': dataset_path, 'type': 'local', 'weight': 1.0}]}, ans) def test_rewrite_cli_datapath_local_directory(self): @@ -44,7 +44,7 @@ def test_rewrite_cli_datapath_local_directory(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': dataset_path, 'type': 'ondisk', 'weight': 1.0}]}, + {'path': dataset_path, 'type': 'local', 'weight': 1.0}]}, ans) def test_rewrite_cli_datapath_hf(self): @@ -67,8 +67,8 @@ def test_rewrite_cli_datapath_with_weights(self): ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ - {'path': './data/sample.json', 'type': 'ondisk', 'weight': 0.5}, - {'path': './data/sample.txt', 'type': 'ondisk', 'weight': 1.0}]}, + {'path': './data/sample.json', 'type': 'local', 'weight': 0.5}, + {'path': './data/sample.txt', 'type': 'local', 'weight': 1.0}]}, ans) @patch('os.path.isdir') @@ -81,9 +81,9 @@ def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir): dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl" expected = { 'configs': [ - {'type': 'ondisk', 'path': 'ds1.jsonl', 'weight': 1.0}, - {'type': 'ondisk', 'path': 'ds2_dir', 'weight': 2.0}, - {'type': 'ondisk', 'path': 'ds3.jsonl', 'weight': 3.0} + {'type': 'local', 'path': 'ds1.jsonl', 'weight': 1.0}, + {'type': 'local', 'path': 'ds2_dir', 'weight': 2.0}, + {'type': 'local', 'path': 'ds3.jsonl', 'weight': 3.0} ] } result = rewrite_cli_datapath(dataset_path) @@ -113,7 +113,7 @@ def test_rewrite_cli_datapath_with_max_samples(self): expected = { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': './data/sample.txt', 'weight': 1.0 }], @@ -129,7 +129,7 @@ def test_rewrite_cli_datapath_without_max_samples(self): expected = { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': './data/sample.txt', 'weight': 1.0 }] @@ -255,7 +255,7 @@ def test_builder_single_dataset_config(self): self.base_cfg.dataset = { 'configs': [ { - 'type': 'ondisk', + 'type': 'local', 'path': 'test.jsonl' } ] @@ -269,7 +269,7 @@ def test_builder_single_dataset_config(self): # Verify config content preserved strategy = builder.load_strategies[0] - self.assertEqual(strategy.ds_config['type'], 'ondisk') + self.assertEqual(strategy.ds_config['type'], 'local') self.assertEqual(strategy.ds_config['path'], 'test.jsonl') def test_builder_multiple_dataset_config(self): @@ -278,11 +278,11 @@ def test_builder_multiple_dataset_config(self): self.base_cfg.dataset = { 'configs': [ { - 'type': 'ondisk', + 'type': 'local', 'path': 'test1.jsonl' }, { - 'type': 'ondisk', + 'type': 'local', 'path': 'test2.jsonl' } ] @@ -311,7 +311,7 @@ def test_builder_mixed_dataset_types(self): self.base_cfg.dataset = { 'configs': [ { - 'type': 'ondisk', + 'type': 'local', 'path': 'test1.jsonl' }, { @@ -373,9 +373,9 @@ def test_builder_ondisk_config(self): with redirect_stdout(out): cfg = init_configs(args=f'--config {test_config_file}'.split()) self.assertIsInstance(cfg, Namespace) - self.assertEqual(cfg.project_name, 'dataset-ondisk-json') + self.assertEqual(cfg.project_name, 'dataset-local-json') self.assertEqual(cfg.dataset, - {'configs': [{'path': 'sample.json', 'type': 'ondisk'}]}) + {'configs': [{'path': 'sample.json', 'type': 'local'}]}) self.assertEqual(not cfg.dataset_path, True) def test_builder_ondisk_config_list(self): @@ -384,11 +384,11 @@ def test_builder_ondisk_config_list(self): with redirect_stdout(out): cfg = init_configs(args=f'--config {test_config_file}'.split()) self.assertIsInstance(cfg, Namespace) - self.assertEqual(cfg.project_name, 'dataset-ondisk-list') + self.assertEqual(cfg.project_name, 'dataset-local-list') self.assertEqual(cfg.dataset, {'configs': [ - {'path': 'sample.json', 'type': 'ondisk'}, - {'path': 'sample.txt', 'type': 'ondisk'} + {'path': 'sample.json', 'type': 'local'}, + {'path': 'sample.txt', 'type': 'local'} ]}) self.assertEqual(not cfg.dataset_path, True) @@ -396,7 +396,7 @@ def test_builder_with_max_samples(self): """Test DatasetBuilder with max_sample_num""" self.base_cfg.dataset = { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': 'test.jsonl', 'weight': 1.0 }], @@ -412,7 +412,7 @@ def test_builder_without_max_samples(self): """Test DatasetBuilder without max_sample_num""" self.base_cfg.dataset = { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': 'test.jsonl', 'weight': 1.0 }] @@ -428,12 +428,12 @@ def test_mixed_dataset_configs(self): self.base_cfg.dataset = { 'configs': [ { - 'type': 'ondisk', + 'type': 'local', 'path': 'test1.jsonl', 'weight': 1.0 }, { - 'type': 'ondisk', + 'type': 'local', 'path': 'test2.jsonl', 'weight': 2.0 } @@ -461,7 +461,7 @@ def test_invalid_max_sample_num(self): for value in invalid_values: self.base_cfg.dataset = { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': 'test.jsonl', 'weight': 1.0 }], @@ -489,7 +489,7 @@ def test_builder_ray_config(self): # Verify dataset config self.assertEqual(cfg.dataset, { 'configs': [{ - 'type': 'ondisk', + 'type': 'local', 'path': './demos/process_on_ray/data/demo-dataset.jsonl', 'weight': 1.0 }] @@ -498,7 +498,7 @@ def test_builder_ray_config(self): # Create builder and verify builder = DatasetBuilder(cfg, executor_type=cfg.executor_type) self.assertEqual(len(builder.load_strategies), 1) - self.assertIsInstance(builder.load_strategies[0], RayOndiskJsonDataLoadStrategy) + self.assertIsInstance(builder.load_strategies[0], RayLocalJsonDataLoadStrategy) if __name__ == '__main__': unittest.main() From acccc0117463245556d7a63c4bbb7a114b25de65 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 27 Jan 2025 12:32:36 -0800 Subject: [PATCH 42/56] pass through num_proc param for ray executor when loading dataset --- data_juicer/core/executor/ray_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index 5304c2688..4512f483f 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -77,7 +77,7 @@ def run(self, """ # 1. load data logger.info('Loading dataset with Ray...') - dataset = self.datasetbuilder.load_dataset() + dataset = self.datasetbuilder.load_dataset(num_proc=load_data_np) # 2. extract processes logger.info('Preparing process operators...') From 1823cd68a0d912f16e0ba5fb4b69127f4885c9fb Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Mon, 27 Jan 2025 14:52:16 -0800 Subject: [PATCH 43/56] fix bugs for huggingface dataset loading; add sample config --- configs/demo/process-huggingface.yaml | 22 ++++++++++++++ data_juicer/core/data/dataset_builder.py | 11 ++++--- data_juicer/core/data/load_strategy.py | 37 ++++++++++++++++-------- 3 files changed, 54 insertions(+), 16 deletions(-) create mode 100644 configs/demo/process-huggingface.yaml diff --git a/configs/demo/process-huggingface.yaml b/configs/demo/process-huggingface.yaml new file mode 100644 index 000000000..cafbf54d6 --- /dev/null +++ b/configs/demo/process-huggingface.yaml @@ -0,0 +1,22 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-process' +dataset: + configs: + - type: 'remote' + source: 'huggingface' + path: 'hugfaceguy0001/retarded_bar' + name: 'question' + split: 'train' + +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo-process/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - language_id_score_filter: + lang: 'zh' + min_score: 0.8 diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 0203554f2..9304eafc0 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -83,10 +83,13 @@ def __init__(self, cfg: Namespace, executor_type: str = 'default'): # initialize data loading strategy data_type = ds_config.get('type', None) data_source = ds_config.get('source', None) - self.load_strategies.append( - DataLoadStrategyRegistry.get_strategy_class( - self.executor_type, data_type, data_source)(ds_config, - cfg=self.cfg)) + stra = DataLoadStrategyRegistry.get_strategy_class( + self.executor_type, data_type, data_source)(ds_config, + cfg=self.cfg) + if stra is None: + raise ValueError(f'No data load strategy found for' + f' {data_type} {data_source}') + self.load_strategies.append(stra) # initialzie the sample numbers self.max_sample_num = ds_configs.get('max_sample_num', None) diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 9473356ac..07d065253 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Type, Union import datasets +from loguru import logger from data_juicer.core.data import DJDataset, RayDataset from data_juicer.core.data.config_validator import ConfigValidator @@ -78,6 +79,11 @@ def get_strategy_class( 1. Exact match 2. Wildcard matches from most specific to most general """ + logger.info(f'Getting strategy class for ' + f'exec: {executor_type}, ' + f'data_type: {data_type}, ' + f'data_source: {data_source}') + # default to wildcard if not provided executor_type = executor_type or '*' data_type = data_type or '*' @@ -113,9 +119,12 @@ def specificity_score(key: StrategyKey) -> int: if part == '*') matching_strategies.sort(key=lambda x: specificity_score(x[0])) - return matching_strategies[0][1] + found = matching_strategies[0][1] + logger.info(f'Found matching strategies: {found}') + return found # No matching strategy found + logger.warning('No matching strategy found') return None @classmethod @@ -247,7 +256,8 @@ class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy): CONFIG_VALIDATION_RULES = { 'required_fields': ['path'], - 'optional_fields': ['split', 'limit', 'name'], + 'optional_fields': + ['split', 'limit', 'name', 'data_files', 'data_dir'], 'field_types': { 'path': str }, @@ -256,16 +266,19 @@ class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy): def load_data(self, **kwargs): num_proc = kwargs.pop('num_proc', 1) - ds = datasets.load_dataset(self.ds_config['path'], - split=self.ds_config.split, - name=self.ds_config.name, - limit=self.ds_config.limit, - num_proc=num_proc, - **kwargs) - ds = unify_format(ds, - text_keys=self.text_keys, - num_proc=num_proc, - global_cfg=self.cfg) + ds = datasets.load_dataset( + self.ds_config['path'], + split=self.ds_config.get('split', None), + data_files=self.ds_config.get('data_files', None), + data_dir=self.ds_config.get('data_dir', None), + name=self.ds_config.get('name', None), + limit=self.ds_config.get('limit', None), + num_proc=num_proc, + **kwargs) + return unify_format(ds, + text_keys=self.cfg.text_keys, + num_proc=num_proc, + global_cfg=self.cfg) @DataLoadStrategyRegistry.register('default', 'remote', 'modelscope') From 2963118de2655b748e54f0ca4eda6f6fdc82e589 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 29 Jan 2025 10:19:47 -0800 Subject: [PATCH 44/56] fix typo in configs --- configs/datasets/local_json.yaml | 2 +- configs/datasets/local_parquet.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/datasets/local_json.yaml b/configs/datasets/local_json.yaml index 791ec9f32..6d64c5819 100644 --- a/configs/datasets/local_json.yaml +++ b/configs/datasets/local_json.yaml @@ -1,5 +1,5 @@ # global parameters -project_name: 'dataset-ondisk-json' +project_name: 'dataset-local-json' dataset: configs: - type: 'local' diff --git a/configs/datasets/local_parquet.yaml b/configs/datasets/local_parquet.yaml index bfded66f8..09afd91e1 100644 --- a/configs/datasets/local_parquet.yaml +++ b/configs/datasets/local_parquet.yaml @@ -1,5 +1,5 @@ # global parameters -project_name: 'dataset-ondisk-parquet' +project_name: 'dataset-local-parquet' dataset: configs: - type: 'local' From 4472aef2908791d8d9866bf3e0552601f8e8983e Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 6 Feb 2025 18:03:59 -0800 Subject: [PATCH 45/56] remove absolute path logic; remove dup test files --- data_juicer/core/data/dataset_builder.py | 4 +- data_juicer/core/data/ray_dataset.py | 53 +------------------ .../configs/demo-new-config.yaml | 2 +- .../configs/demo-new-config.yaml | 2 +- tests/ops/data/img1_dup.png | 1 - tests/ops/data/img2_dup.jpg | 1 - tests/ops/data/img3_dup.jpg | 1 - tests/ops/data/img3_dup_dup.jpg | 1 - tests/ops/data/video1_dup.mp4 | 1 - tests/ops/data/video2_dup.mp4 | 1 - tests/ops/data/video3_dup.mp4 | 1 - tests/ops/data/video3_dup_dup.mp4 | 1 - 12 files changed, 6 insertions(+), 63 deletions(-) delete mode 120000 tests/ops/data/img1_dup.png delete mode 120000 tests/ops/data/img2_dup.jpg delete mode 120000 tests/ops/data/img3_dup.jpg delete mode 120000 tests/ops/data/img3_dup_dup.jpg delete mode 120000 tests/ops/data/video1_dup.mp4 delete mode 120000 tests/ops/data/video2_dup.mp4 delete mode 120000 tests/ops/data/video3_dup.mp4 delete mode 120000 tests/ops/data/video3_dup_dup.mp4 diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 9304eafc0..e8afaf7e9 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -47,7 +47,7 @@ def __init__(self, cfg: Namespace, executor_type: str = 'default'): 'in configurations') # validate dataset config for type constraints - # TODO other constraints; ray dataset only supports ondisk, etc. + # TODO other constraints; ray dataset only supports local, etc. if type(ds_configs) != dict: raise ConfigValidationError( 'Dataset config should be a dictionary') @@ -72,7 +72,7 @@ def __init__(self, cfg: Namespace, executor_type: str = 'default'): ] if len(set(types)) > 1: raise ConfigValidationError( - 'Mixture of diff types (ONDISK/REMOTE/...) are not supported') + 'Mixture of diff types (LOCAL/REMOTE/...) are not supported') if types[0] == 'remote' and len(ds_configs['configs']) > 1: raise ConfigValidationError( 'Multiple remote datasets are not supported') diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index b03a91cfe..88b45290b 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os from functools import partial from typing import Any, Dict, List, Literal, Optional, Union @@ -19,55 +18,6 @@ ds = LazyLoader('ds', 'ray.data.read_api') -def get_abs_path(path, dataset_dir): - full_path = os.path.abspath(os.path.join(dataset_dir, path)) - if os.path.exists(full_path): - return full_path - else: - return path - - -def convert_to_absolute_paths(samples, dataset_dir, path_keys): - samples = samples.to_pydict() - for key in path_keys: - for idx in range(len(samples[key])): - paths = samples[key][idx] - if isinstance(paths, str): - samples[key][idx] = get_abs_path(paths, dataset_dir) - elif isinstance(paths, list): - samples[key][idx] = [ - get_abs_path(item, dataset_dir) for item in paths - ] - return pyarrow.Table.from_pydict(samples) - - -# TODO: check path for nestdataset -def set_dataset_to_absolute_path(dataset, dataset_path, cfg): - """ - Set all the path in input data to absolute path. - Checks dataset_dir and project_dir for valid paths. - """ - path_keys = [] - columns = dataset.columns() - for key in [cfg.video_key, cfg.image_key, cfg.audio_key]: - if key in columns: - path_keys.append(key) - if len(path_keys) > 0: - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map_batches(partial(convert_to_absolute_paths, - dataset_dir=dataset_dir, - path_keys=path_keys), - batch_format='pyarrow', - zero_copy_batch=True) - return dataset - - -def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: - if dataset_path: - dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - return dataset - - def get_num_gpus(op, op_proc): if not op.use_cuda(): return 0 @@ -86,7 +36,8 @@ def __init__(self, dataset: rd.Dataset, dataset_path: str = None, cfg=None) -> None: - self.data = preprocess_dataset(dataset, dataset_path, cfg) + self.data = dataset + # self.data = preprocess_dataset(dataset, dataset_path, cfg) self.num_proc = None if cfg: self.num_proc = cfg.np diff --git a/demos/process_on_ray/configs/demo-new-config.yaml b/demos/process_on_ray/configs/demo-new-config.yaml index 901d33c0e..ba6db49e9 100644 --- a/demos/process_on_ray/configs/demo-new-config.yaml +++ b/demos/process_on_ray/configs/demo-new-config.yaml @@ -4,7 +4,7 @@ project_name: 'ray-demo-new-config' dataset: configs: - - type: ondisk + - type: local path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file weight: 1.0 diff --git a/demos/process_video_on_ray/configs/demo-new-config.yaml b/demos/process_video_on_ray/configs/demo-new-config.yaml index 6300b2e8c..80569715d 100644 --- a/demos/process_video_on_ray/configs/demo-new-config.yaml +++ b/demos/process_video_on_ray/configs/demo-new-config.yaml @@ -5,7 +5,7 @@ project_name: 'ray-demo' executor_type: 'ray' dataset: configs: - - type: ondisk + - type: local path: './demos/process_video_on_ray/data/demo-dataset.jsonl' # path to your dataset directory or file ray_address: 'auto' # change to your ray cluster address, e.g., ray://: export_path: './outputs/demo/demo-processed-ray-videos' diff --git a/tests/ops/data/img1_dup.png b/tests/ops/data/img1_dup.png deleted file mode 120000 index d62a85900..000000000 --- a/tests/ops/data/img1_dup.png +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img1.png \ No newline at end of file diff --git a/tests/ops/data/img2_dup.jpg b/tests/ops/data/img2_dup.jpg deleted file mode 120000 index 8a99a2526..000000000 --- a/tests/ops/data/img2_dup.jpg +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img2.jpg \ No newline at end of file diff --git a/tests/ops/data/img3_dup.jpg b/tests/ops/data/img3_dup.jpg deleted file mode 120000 index 6e8c435e3..000000000 --- a/tests/ops/data/img3_dup.jpg +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img3.jpg \ No newline at end of file diff --git a/tests/ops/data/img3_dup_dup.jpg b/tests/ops/data/img3_dup_dup.jpg deleted file mode 120000 index f539c0972..000000000 --- a/tests/ops/data/img3_dup_dup.jpg +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/img3_dup.jpg \ No newline at end of file diff --git a/tests/ops/data/video1_dup.mp4 b/tests/ops/data/video1_dup.mp4 deleted file mode 120000 index 6d1bbbc84..000000000 --- a/tests/ops/data/video1_dup.mp4 +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video1.mp4 \ No newline at end of file diff --git a/tests/ops/data/video2_dup.mp4 b/tests/ops/data/video2_dup.mp4 deleted file mode 120000 index 8fa6335be..000000000 --- a/tests/ops/data/video2_dup.mp4 +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video2.mp4 \ No newline at end of file diff --git a/tests/ops/data/video3_dup.mp4 b/tests/ops/data/video3_dup.mp4 deleted file mode 120000 index f63158860..000000000 --- a/tests/ops/data/video3_dup.mp4 +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video3.mp4 \ No newline at end of file diff --git a/tests/ops/data/video3_dup_dup.mp4 b/tests/ops/data/video3_dup_dup.mp4 deleted file mode 120000 index 6a225ba39..000000000 --- a/tests/ops/data/video3_dup_dup.mp4 +++ /dev/null @@ -1 +0,0 @@ -/Users/yilei.z/dev/data-juicer/tests/ops/deduplicator/../data/video3_dup.mp4 \ No newline at end of file From 7964867cf9ac6dbeac99835d44dad7202dce6f29 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 6 Feb 2025 18:20:20 -0800 Subject: [PATCH 46/56] update .gitignore for dup files in tests --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 48108c132..987d33bc3 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ __pycache__ .vscode/ **/__dj__produced_data__/* venv/ + +# dup files created by tests +tests/ops/data/*dup* From 96207ba5a80b61f265de6591e750e1bf28433bb6 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 7 Feb 2025 10:57:46 -0800 Subject: [PATCH 47/56] fix RayDataset schema validation issue --- data_juicer/core/data/data_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/core/data/data_validator.py b/data_juicer/core/data/data_validator.py index 5f0d9346d..4225cc78a 100644 --- a/data_juicer/core/data/data_validator.py +++ b/data_juicer/core/data/data_validator.py @@ -130,7 +130,7 @@ def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: if isinstance(dataset, NestedDataset): available_fields = set(dataset.column_names) elif isinstance(dataset, RayDataset): - available_fields = set(dataset.schema().names) + available_fields = set(dataset.data.schema().names) else: raise DataValidationError( f'Unsupported dataset type: {type(dataset)}') From 9b1d7382bdc76a84accb1d98b89c9e4dd4d9a0d3 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 7 Feb 2025 11:43:28 -0800 Subject: [PATCH 48/56] fix wiki downloader tests --- data_juicer/download/wikipedia.py | 9 +- tests/download/test_download.py | 139 ++++++++++++++++++++++-------- 2 files changed, 110 insertions(+), 38 deletions(-) diff --git a/data_juicer/download/wikipedia.py b/data_juicer/download/wikipedia.py index 6e8e29b52..03c9f24fe 100644 --- a/data_juicer/download/wikipedia.py +++ b/data_juicer/download/wikipedia.py @@ -9,12 +9,13 @@ import mwparserfromhell from datasets import Dataset +from data_juicer.download.downloader import (DocumentDownloader, + DocumentExtractor, + DocumentIterator, + download_and_extract, + get_wikipedia_urls) from data_juicer.utils.file_utils import expand_outdir_and_mkdir -from .downloader import (DocumentDownloader, DocumentExtractor, - DocumentIterator, download_and_extract, - get_wikipedia_urls) - # The majority of this code is taken from the HuggingFace # implementation of the Wikipedia dataset preparation: # https://github.com/huggingface/datasets/blob/7e30308f49f8c85dc7a2ab5aafbff04b5d2f38e2/datasets/wikipedia/wikipedia.py diff --git a/tests/download/test_download.py b/tests/download/test_download.py index a86892cbd..8ae36f490 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -3,8 +3,11 @@ import tempfile import os import shutil +import json +from datasets import Dataset from data_juicer.download.wikipedia import ( - get_wikipedia_urls, download_wikipedia + get_wikipedia_urls, download_wikipedia, + WikipediaDownloader, WikipediaIterator, WikipediaExtractor ) class TestDownload(unittest.TestCase): @@ -26,27 +29,54 @@ def test_wikipedia_urls(self): ] with patch('requests.get') as mock_get: - # Mock the response from Wikipedia API - mock_response = MagicMock() - mock_response.text = "some HTML containing the dump files" - mock_get.return_value = mock_response + def mock_get_response(*args, **kwargs): + url = args[0] + mock_response = MagicMock() + + if 'dumpstatus.json' in url: + mock_response.content = bytes(json.dumps({ + "jobs": { + "articlesmultistreamdump": { + "files": { + "enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2": { + "url": expected_urls[0] + }, + "enwiki-20241101-pages-articles-multistream2.xml-p41243p151573.bz2": { + "url": expected_urls[1] + }, + "enwiki-20241101-pages-articles-multistream3.xml-p311574p311329.bz2": { + "url": expected_urls[2] + } + } + } + } + }), 'utf-8') + else: + mock_response.content = bytes(""" + + + 20241101/ + + + """, 'utf-8') + + return mock_response + + mock_get.side_effect = mock_get_response urls = get_wikipedia_urls(dump_date=dump_date) - # Verify the function made the correct API call - mock_get.assert_called_once_with( - f"https://dumps.wikimedia.org/enwiki/{dump_date}/") - # Verify returned URLs - assert len(urls) > 3 + assert len(urls) == 3 assert urls[0] == expected_urls[0] assert urls[1] == expected_urls[1] assert urls[2] == expected_urls[2] + @patch('data_juicer.download.wikipedia.get_wikipedia_urls') - @patch('data_juicer.download.wikipedia.download_file') - @patch('data_juicer.download.wikipedia.process_wiki_dump') - def test_wikipedia_download(self, mock_process, mock_download, mock_get_urls): + @patch('data_juicer.download.downloader.download_and_extract') + @patch('data_juicer.download.wikipedia.download_and_extract') # Add this patch too + def test_wikipedia_download(self, mock_download_and_extract_wiki, mock_download_and_extract, mock_get_urls): dump_date = "20241101" url_limit = 1 item_limit = 50 @@ -57,36 +87,77 @@ def test_wikipedia_download(self, mock_process, mock_download, mock_get_urls): ] mock_get_urls.return_value = mock_urls - # Mock the download process - mock_download.return_value = "/tmp/mock_downloaded_file.bz2" + # Create expected output paths + output_paths = [ + os.path.join(self.temp_dir, "enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl") + ] + + # Create mock dataset + mock_dataset = Dataset.from_dict({ + 'text': [f"Article {i}" for i in range(10)], + 'title': [f"Title {i}" for i in range(10)], + 'id': [str(i) for i in range(10)], + 'url': [f"https://en.wikipedia.org/wiki/Title_{i}" for i in range(10)], + 'language': ['en'] * 10, + 'source_id': ['enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2'] * 10, + 'filename': ['enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl'] * 10 + }) - # Mock the processing result - mock_df = MagicMock() - mock_df.take.return_value = [{"text": f"Article {i}"} for i in range(10)] - mock_process.return_value = mock_df + # Set return value for both mocks + mock_download_and_extract.return_value = mock_dataset + mock_download_and_extract_wiki.return_value = mock_dataset + # Add print statements to debug + print("Before calling download_wikipedia") + # Run the function - wiki_df = download_wikipedia( - self.temp_dir, - dump_date=dump_date, - url_limit=url_limit, + result = download_wikipedia( + self.temp_dir, + dump_date=dump_date, + url_limit=url_limit, item_limit=item_limit ) + print("After calling download_wikipedia") + + # Print mock call counts + print(f"mock_download_and_extract.call_count: {mock_download_and_extract.call_count}") + print(f"mock_download_and_extract_wiki.call_count: {mock_download_and_extract_wiki.call_count}") + # Verify the calls - mock_get_urls.assert_called_once_with(dump_date=dump_date) - mock_download.assert_called_once_with( - mock_urls[0], - os.path.join(self.temp_dir, os.path.basename(mock_urls[0])) - ) - mock_process.assert_called_once() + mock_get_urls.assert_called_once_with(language='en', dump_date=dump_date) + + # Try both mocks + if mock_download_and_extract.call_count > 0: + mock = mock_download_and_extract + else: + mock = mock_download_and_extract_wiki + + # Verify download_and_extract was called with correct arguments + mock.assert_called_once() + call_args = mock.call_args[0] + assert call_args[0] == mock_urls[:url_limit] # urls (limited by url_limit) + assert call_args[1] == output_paths # output_paths + assert isinstance(call_args[2], WikipediaDownloader) # downloader + assert isinstance(call_args[3], WikipediaIterator) # iterator + assert isinstance(call_args[4], WikipediaExtractor) # extractor + + # Verify the output format + expected_format = { + 'text': str, + 'title': str, + 'id': str, + 'url': str, + 'language': str, + 'source_id': str, + 'filename': str, + } + assert call_args[5] == expected_format # output_format # Verify the result - sample = wiki_df.take(10) - assert len(sample) == 10 - - # Verify the mocks were used correctly - mock_df.take.assert_called_once_with(10) + assert isinstance(result, Dataset) + assert len(result) == 10 + assert all(field in result.features for field in expected_format.keys()) if __name__ == '__main__': From 828e7ba01c7fd6cd3d7ef16e2ec526eef31a18c9 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 7 Feb 2025 12:23:25 -0800 Subject: [PATCH 49/56] remove mixture formatter; logic captured in dataloader --- data_juicer/format/__init__.py | 5 +- data_juicer/format/mixture_formatter.py | 88 ------------------------- docs/Operators.md | 15 ++--- tests/format/test_mixture_formatter.py | 74 --------------------- 4 files changed, 9 insertions(+), 173 deletions(-) delete mode 100644 data_juicer/format/mixture_formatter.py delete mode 100644 tests/format/test_mixture_formatter.py diff --git a/data_juicer/format/__init__.py b/data_juicer/format/__init__.py index 368dcf220..4885cffd8 100644 --- a/data_juicer/format/__init__.py +++ b/data_juicer/format/__init__.py @@ -5,13 +5,12 @@ from .empty_formatter import EmptyFormatter, RayEmptyFormatter from .formatter import LocalFormatter, RemoteFormatter from .json_formatter import JsonFormatter -from .mixture_formatter import MixtureFormatter from .parquet_formatter import ParquetFormatter from .text_formatter import TextFormatter from .tsv_formatter import TsvFormatter __all__ = [ 'JsonFormatter', 'LocalFormatter', 'RemoteFormatter', 'TextFormatter', - 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', 'MixtureFormatter', - 'EmptyFormatter', 'RayEmptyFormatter' + 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', 'EmptyFormatter', + 'RayEmptyFormatter' ] diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py deleted file mode 100644 index c62a6b845..000000000 --- a/data_juicer/format/mixture_formatter.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import List, Union - -import numpy as np -from datasets import Dataset, concatenate_datasets -from loguru import logger - -from data_juicer.format.formatter import BaseFormatter -from data_juicer.utils.sample import random_sample - - -class MixtureFormatter(BaseFormatter): - """The class mixes multiple datasets by randomly selecting samples from - every dataset and merging them, and then exports the merged datasset as a - new mixed dataset.""" - - def __init__(self, - dataset_path: str, - suffixes: Union[str, List[str], None] = None, - text_keys=None, - add_suffix=False, - max_samples=None, - **kwargs): - """ - Initialization method. - - :param dataset_path: a dataset file or a dataset dir or a list - of them, optional weights, default 1.0 e.g. ` ds.jsonl - ds_dir ds_file.json` - :param suffixes: files with specified suffixes to be processed - :param text_keys: key names of field that stores sample text. - :param add_suffix: whether to add the file suffix to dataset - meta info - :param max_samples: max samples number of mixed dataset. - :param kwargs: extra args - """ - - data_prefixes, weights = self._get_weight(data_prefix=dataset_path) - sample_numbers = [0] * len(weights) - if max_samples is not None: - # Normalize weights. - weights = np.array(weights, dtype=np.float64) - sum_weights = np.sum(weights) - assert sum_weights > 0.0 - weights /= sum_weights - sample_num_per_dataset = [ - int(np.ceil(max_samples * weight)) for weight in weights - ] - - # Adjust - acc_sample_numbers = 0 - for i in range(len(sample_num_per_dataset)): - sample_numbers[i] = min(sample_num_per_dataset[i], - max_samples - acc_sample_numbers) - acc_sample_numbers += sample_numbers[i] - - self.sample_numbers = sample_numbers - self.weights = weights - self.formatters = None - # [ - # load_formatter(dataset_path=data_prefix, - # suffixes=suffixes, - # text_keys=text_keys, - # add_suffix=add_suffix, - # **kwargs) for data_prefix in data_prefixes - # ] - - def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: - """ - Load a mixed dataset. - - :param num_proc: number of processes when loading the dataset - :param global_cfg: the global cfg used in consequent processes, - :return: mixed dataset - """ - dataset_list = [] - for weight, sample_num, formatter in zip(self.weights, - self.sample_numbers, - self.formatters): - dataset = formatter.load_dataset(num_proc, global_cfg) - sampled = random_sample(dataset, weight, sample_num) - logger.info(f'sampled {len(sampled)} from ' - f'{len(dataset)}') - dataset_list.append(sampled) - - from data_juicer.core.data import NestedDataset - mixed_dataset = NestedDataset(concatenate_datasets(dataset_list)) - logger.info(f'There are {len(mixed_dataset)} in final dataset') - return mixed_dataset diff --git a/docs/Operators.md b/docs/Operators.md index df9ccc348..ca82ba9e4 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -35,7 +35,7 @@ Data-Juicer 中的算子分为以下 7 种类型。 | [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | | [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | | [filter](#filter) | 45 | Filters out low-quality samples. 过滤低质量样本。 | -| [formatter](#formatter) | 9 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | +| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | | [mapper](#mapper) | 75 | Edits and transforms samples. 对数据样本进行编辑和转换。 | | [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | @@ -67,7 +67,7 @@ All the specific operators are listed below, each featured with several capabili | entity_attribute_aggregator | 💻CPU 🔗API 🟢Stable | Return conclusion of the given entity's attribute from some docs. 从一些文档返回给定实体的属性的结论。 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | | meta_tags_aggregator | 💻CPU 🔗API 🟢Stable | Merge similar meta tags to one tag. 将类似的元标记合并到一个标记。 | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) | | most_relavant_entities_aggregator | 💻CPU 🔗API 🟢Stable | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. 从一些文本中提取与给定实体密切相关的实体,并按重要性的降序对它们进行排序。 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | -| nested_aggregator | 🔤Text 💻CPU 🔗API 🟢Stable | Considering the limitation of input length, nested aggregate contents for each given number of samples. 考虑到输入长度的限制,嵌套聚合每个给定数量的样本的内容。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | +| nested_aggregator | 🔤Text 💻CPU 🔗API 🟢Stable | Considering the limitation of input length, nested aggregate contents for each given number of samples. 考虑到输入长度的限制,为每个给定数量的样本嵌套聚合内容。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | ## deduplicator @@ -105,7 +105,7 @@ All the specific operators are listed below, each featured with several capabili | image_size_filter | 🏞Image 💻CPU 🟢Stable | Keep data samples whose image size (in Bytes/KB/MB/...) within a specific range. 保留图像大小 (以字节/KB/MB/... 为单位) 在特定范围内的数据样本。 | [code](../data_juicer/ops/filter/image_size_filter.py) | [tests](../tests/ops/filter/test_image_size_filter.py) | | image_text_matching_filter | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Filter to keep samples those matching score between image and text within a specific range. 过滤器将图像和文本之间的匹配分数保持在特定范围内。 | [code](../data_juicer/ops/filter/image_text_matching_filter.py) | [tests](../tests/ops/filter/test_image_text_matching_filter.py) | | image_text_similarity_filter | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Filter to keep samples those similarities between image and text within a specific range. 过滤器将图像和文本之间的相似性保持在特定范围内。 | [code](../data_juicer/ops/filter/image_text_similarity_filter.py) | [tests](../tests/ops/filter/test_image_text_similarity_filter.py) | -| image_watermark_filter | 🏞Image 🚀GPU 🧩HF 🟢Stable | Filter to keep samples whose images have no watermark with high probability. 过滤器以保持其图像没有水印的样本具有高概率。 | [code](../data_juicer/ops/filter/image_watermark_filter.py) | [tests](../tests/ops/filter/test_image_watermark_filter.py) | +| image_watermark_filter | 🏞Image 🚀GPU 🧩HF 🟢Stable | Filter to keep samples whose images have no watermark with high probability. 过滤器,以保留图像没有水印的样本。 | [code](../data_juicer/ops/filter/image_watermark_filter.py) | [tests](../tests/ops/filter/test_image_watermark_filter.py) | | language_id_score_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples in a specific language with confidence score larger than a specific min value. 过滤器以保留置信度得分大于特定最小值的特定语言的样本。 | [code](../data_juicer/ops/filter/language_id_score_filter.py) | [tests](../tests/ops/filter/test_language_id_score_filter.py) | | maximum_line_length_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with maximum line length within a specific range. 过滤器将最大行长度的样本保持在特定范围内。 | [code](../data_juicer/ops/filter/maximum_line_length_filter.py) | [tests](../tests/ops/filter/test_maximum_line_length_filter.py) | | perplexity_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with perplexity score less than a specific max value. 过滤以保留困惑度分数小于特定最大值的样本。 | [code](../data_juicer/ops/filter/perplexity_filter.py) | [tests](../tests/ops/filter/test_perplexity_filter.py) | @@ -142,10 +142,9 @@ All the specific operators are listed below, each featured with several capabili | empty_formatter | 🟢Stable | The class is used to create empty data. 类用于创建空数据。 | [code](../data_juicer/format/empty_formatter.py) | [tests](../tests/format/test_empty_formatter.py) | | json_formatter | 🔴Alpha | The class is used to load and format json-type files. 类用于加载和格式化json类型的文件。 | [code](../data_juicer/format/json_formatter.py) | - | | local_formatter | 🟢Stable | The class is used to load a dataset from local files or local directory. 类用于从本地文件或本地目录加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) | -| mixture_formatter | 🟢Stable | The class mixes multiple datasets by randomly selecting samples from every dataset and merging them, and then exports the merged datasset as a new mixed dataset. 该类通过从每个数据集中随机选择样本并合并它们来混合多个数据集,然后将合并的datasset导出为新的混合数据集。 | [code](../data_juicer/format/mixture_formatter.py) | [tests](../tests/format/test_mixture_formatter.py) | | parquet_formatter | 🟢Stable | The class is used to load and format parquet-type files. 该类用于加载和格式化镶木地板类型的文件。 | [code](../data_juicer/format/parquet_formatter.py) | [tests](../tests/format/test_parquet_formatter.py) | | remote_formatter | 🟢Stable | The class is used to load a dataset from repository of huggingface hub. 该类用于从huggingface hub的存储库加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) | -| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型文件。 | [code](../data_juicer/format/text_formatter.py) | - | +| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型的文件。 | [code](../data_juicer/format/text_formatter.py) | - | | tsv_formatter | 🟢Stable | The class is used to load and format tsv-type files. 该类用于加载和格式化tsv类型的文件。 | [code](../data_juicer/format/tsv_formatter.py) | [tests](../tests/format/test_tsv_formatter.py) | ## grouper @@ -161,7 +160,7 @@ All the specific operators are listed below, each featured with several capabili | Operator 算子 | Tags 标签 | Description 描述 | Source code 源码 | Unit tests 单测样例 | |----------|------|-------------|-------------|------------| | audio_ffmpeg_wrapped_mapper | 📣Audio 💻CPU 🟢Stable | Simple wrapper for FFmpeg audio filters. FFmpeg音频滤波器的简单包装。 | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) | -| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问题-答案对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | +| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问答对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | | calibrate_query_mapper | 💻CPU 🟢Stable | Mapper to calibrate query in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的查询。 | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | | calibrate_response_mapper | 💻CPU 🟢Stable | Mapper to calibrate response in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的响应。 | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | | chinese_convert_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. 映射器在繁体中文,简体中文和日语汉字之间转换中文。 | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) | @@ -189,7 +188,7 @@ All the specific operators are listed below, each featured with several capabili | image_captioning_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Mapper to generate samples whose captions are generated based on another model and the figure. 映射器生成样本,其标题是基于另一个模型和图生成的。 | [code](../data_juicer/ops/mapper/image_captioning_mapper.py) | [tests](../tests/ops/mapper/test_image_captioning_mapper.py) | | image_diffusion_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Generate image by diffusion model. 通过扩散模型生成图像。 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) | | image_face_blur_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to blur faces detected in images. 映射器模糊图像中检测到的人脸。 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) | -| image_segment_mapper | 🏞Image 🚀GPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | +| image_segment_mapper | 🏞Image 🚀GPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-everything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | | image_tagging_mapper | 🏞Image 🚀GPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) | | mllm_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [code](../data_juicer/ops/mapper/mllm_mapper.py) | [tests](../tests/ops/mapper/test_mllm_mapper.py) | | nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in English based on nlpaug library. 映射器基于nlpaug库简单地增加英语样本。 | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) | @@ -211,7 +210,7 @@ All the specific operators are listed below, each featured with several capabili | remove_long_words_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to remove long words within a specific range. 映射器删除特定范围内的长词。 | [code](../data_juicer/ops/mapper/remove_long_words_mapper.py) | [tests](../tests/ops/mapper/test_remove_long_words_mapper.py) | | remove_non_chinese_character_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to remove non chinese Character in text samples. 映射器删除文本样本中的非中文字符。 | [code](../data_juicer/ops/mapper/remove_non_chinese_character_mapper.py) | [tests](../tests/ops/mapper/test_remove_non_chinese_character_mapper.py) | | remove_repeat_sentences_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to remove repeat sentences in text samples. 映射器删除文本样本中的重复句子。 | [code](../data_juicer/ops/mapper/remove_repeat_sentences_mapper.py) | [tests](../tests/ops/mapper/test_remove_repeat_sentences_mapper.py) | -| remove_specific_chars_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to clean specific chars in text samples. 映射器来清理文本样本中的特定字符。 | [code](../data_juicer/ops/mapper/remove_specific_chars_mapper.py) | [tests](../tests/ops/mapper/test_remove_specific_chars_mapper.py) | +| remove_specific_chars_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to clean specific chars in text samples. 映射器来清理文本示例中的特定字符。 | [code](../data_juicer/ops/mapper/remove_specific_chars_mapper.py) | [tests](../tests/ops/mapper/test_remove_specific_chars_mapper.py) | | remove_table_text_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to remove table texts from text samples. 映射器从文本样本中删除表文本。 | [code](../data_juicer/ops/mapper/remove_table_text_mapper.py) | [tests](../tests/ops/mapper/test_remove_table_text_mapper.py) | | remove_words_with_incorrect_substrings_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to remove words with incorrect substrings. 映射器删除不正确的子字符串的单词。 | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to replace all content in the text that matches a specific regular expression pattern with a designated replacement string. 映射程序将文本中与特定正则表达式模式匹配的所有内容替换为指定的替换字符串。 | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | diff --git a/tests/format/test_mixture_formatter.py b/tests/format/test_mixture_formatter.py deleted file mode 100644 index a4d339695..000000000 --- a/tests/format/test_mixture_formatter.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import unittest - -from data_juicer.format.mixture_formatter import MixtureFormatter -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - - -class MixtureFormatterTest(DataJuicerTestCaseBase): - - def setUp(self): - self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', 'structured') - self._file = os.path.join(self._path, 'demo-dataset.jsonl') - self._file2 = self._file - - def test_only_file(self): - formatter = MixtureFormatter(self._file) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_sample_weight(self): - formatter = MixtureFormatter('0.5 ' + self._file) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 3) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_sample_number(self): - max_samples = 2 - formatter = MixtureFormatter(self._file, max_samples=max_samples) - ds = formatter.load_dataset() - self.assertEqual(len(ds), max_samples) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_sample_number_weight(self): - max_samples = 2 - formatter = MixtureFormatter('0.5 ' + self._file, - max_samples=max_samples) - ds = formatter.load_dataset() - self.assertEqual(len(ds), max_samples) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_multi_datasets_without_weight(self): - data_path = self._file + ' ' + self._file2 - formatter = MixtureFormatter(data_path) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 12) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_multi_datasets_with_one_weight(self): - data_path = '0.5 ' + self._file + ' ' + self._file2 - formatter = MixtureFormatter(data_path) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 9) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_multi_datasets_with_weight(self): - data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2 - formatter = MixtureFormatter(data_path) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - def test_multi_datasets_with_sample(self): - max_samples = 7 - data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2 - formatter = MixtureFormatter(data_path, max_samples=max_samples) - ds = formatter.load_dataset() - self.assertEqual(len(ds), max_samples) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - - -if __name__ == '__main__': - unittest.main() From 4ffb3cffc7d8bd2874dae553a22b09676a0cacdb Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 13 Feb 2025 09:09:55 -0800 Subject: [PATCH 50/56] remove unused mixture formatter --- data_juicer/format/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/data_juicer/format/__init__.py b/data_juicer/format/__init__.py index 4885cffd8..2a1af7dde 100644 --- a/data_juicer/format/__init__.py +++ b/data_juicer/format/__init__.py @@ -1,6 +1,5 @@ from . import (csv_formatter, empty_formatter, json_formatter, - mixture_formatter, parquet_formatter, text_formatter, - tsv_formatter) + parquet_formatter, text_formatter, tsv_formatter) from .csv_formatter import CsvFormatter from .empty_formatter import EmptyFormatter, RayEmptyFormatter from .formatter import LocalFormatter, RemoteFormatter From 7c16b234174f16bd0d68eef0b65fc01a9c98658f Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 13 Feb 2025 10:28:55 -0800 Subject: [PATCH 51/56] minor fixes for CR comments --- data_juicer/core/data/data_validator.py | 2 +- data_juicer/core/data/load_strategy.py | 4 ++-- data_juicer/core/executor/default_executor.py | 1 - data_juicer/core/executor/ray_executor.py | 4 +++- data_juicer/download/arxiv.py | 2 +- data_juicer/format/load.py | 6 ++++-- environments/minimal_requires.txt | 1 + tests/core/test_dataset_builder.py | 7 +++---- 8 files changed, 15 insertions(+), 12 deletions(-) diff --git a/data_juicer/core/data/data_validator.py b/data_juicer/core/data/data_validator.py index 4225cc78a..d36adb60a 100644 --- a/data_juicer/core/data/data_validator.py +++ b/data_juicer/core/data/data_validator.py @@ -150,7 +150,7 @@ def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None: MAX_SAMPLE_SIZE = 1000 if isinstance(dataset, NestedDataset): sample_size = min(MAX_SAMPLE_SIZE, len(dataset)) - sample = dataset.select(range(sample_size)) + sample = dataset.take(sample_size) values = sample[field] elif isinstance(dataset, RayDataset): # RayDataset sample_size = min(MAX_SAMPLE_SIZE, dataset.data.count()) diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 07d065253..c582eb2f3 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -199,7 +199,7 @@ class RayLocalJsonDataLoadStrategy(RayDataLoadStrategy): } def load_data(self, **kwargs): - dataset = rd.read_json(self.ds_config['path']) + dataset = RayDataset.read_json(self.ds_config['path']) return RayDataset(dataset, dataset_path=self.ds_config['path'], cfg=self.cfg) @@ -218,7 +218,7 @@ class RayHuggingfaceDataLoadStrategy(RayDataLoadStrategy): def load_data(self, **kwargs): raise NotImplementedError( - 'Huggingface data load strategy is not implemented') + 'Huggingface data load strategy for Ray is not implemented') @DataLoadStrategyRegistry.register('default', 'local', '*') diff --git a/data_juicer/core/executor/default_executor.py b/data_juicer/core/executor/default_executor.py index 6d4fdc636..d34b8869a 100644 --- a/data_juicer/core/executor/default_executor.py +++ b/data_juicer/core/executor/default_executor.py @@ -173,7 +173,6 @@ def sample_data(self, Sample a subset from the given dataset. TODO add support other than LocalExecutor - :param executor: executor :param dataset_to_sample: Dataset to sample from. If None, will use the formatter linked by the executor. Default is None. :param load_data_np: number of workers when loading the dataset. diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index 4512f483f..91b76feba 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -104,4 +104,6 @@ def run(self, dataset.data.write_json(self.cfg.export_path, force_ascii=False) tend = time.time() logger.info(f'All Ops are done in {tend - tstart:.3f}s.') - return dataset + + if not skip_return: + return dataset diff --git a/data_juicer/download/arxiv.py b/data_juicer/download/arxiv.py index 513474f1c..267fbd543 100644 --- a/data_juicer/download/arxiv.py +++ b/data_juicer/download/arxiv.py @@ -14,7 +14,7 @@ # The iterator and extractor code are in large part taken # from the Red-Pajama repo -# https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv +# https://github.com/togethercomputer/RedPajama-Data/tree/rp_v1/data_prep/arxiv class ArxivDownloader(DocumentDownloader): diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py index 5bfcf1804..cf9863190 100644 --- a/data_juicer/format/load.py +++ b/data_juicer/format/load.py @@ -15,8 +15,10 @@ def load_formatter(dataset_path, :param dataset_path: Path to dataset file or dataset directory :param text_keys: key names of field that stores sample text. Default: None - :param suffixes: the suffix of files that will be read. Default: - None + :param suffixes: the suffix of files that will be read. + Default: None + :param add_suffix: whether to add the file suffix to dataset meta. + Default: False :return: a dataset formatter. """ diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt index 53b5b9762..2d7b1cc6a 100644 --- a/environments/minimal_requires.txt +++ b/environments/minimal_requires.txt @@ -35,3 +35,4 @@ Pillow fastapi[standard]>=0.100 httpx wordcloud +bs4 diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index c0964fe55..bee4b3cd3 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -17,7 +17,6 @@ WORK_DIR = os.path.dirname(os.path.realpath(__file__)) -@SKIPPED_TESTS.register_module() class DatasetBuilderTest(DataJuicerTestCaseBase): def setUp(self): @@ -32,7 +31,7 @@ def setUp(self): def test_rewrite_cli_datapath_local_single_file(self): - dataset_path = "./data/sample.txt" + dataset_path = os.path.join(WORK_DIR, "data/sample.txt") ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ @@ -40,7 +39,7 @@ def test_rewrite_cli_datapath_local_single_file(self): ans) def test_rewrite_cli_datapath_local_directory(self): - dataset_path = "./data" + dataset_path = os.path.join(WORK_DIR, "data") ans = rewrite_cli_datapath(dataset_path) self.assertEqual( {'configs': [ @@ -58,7 +57,7 @@ def test_rewrite_cli_datapath_hf(self): ans) def test_rewrite_cli_datapath_local_wrong_files(self): - dataset_path = "./missingDir" + dataset_path = os.path.join(WORK_DIR, "missingDir") self.assertRaisesRegex(ValueError, "Unable to load the dataset", rewrite_cli_datapath, dataset_path) From f73dd418f39a6f6cb14360239fabb248942ea348 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 13 Feb 2025 10:43:15 -0800 Subject: [PATCH 52/56] resolve eager RayExecutor importing --- data_juicer/core/__init__.py | 6 +----- data_juicer/core/executor/__init__.py | 4 +--- tests/config/demo_4_dataset_test.yaml | 22 ---------------------- tools/process_data.py | 3 ++- 4 files changed, 4 insertions(+), 31 deletions(-) delete mode 100644 tests/config/demo_4_dataset_test.yaml diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index a3d05cb5c..6fbb9d5d6 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,8 +1,7 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import Executor, ExecutorFactory, RayExecutor -from .executor.base import ExecutorBase +from .executor import Executor from .exporter import Exporter from .monitor import Monitor from .tracer import Tracer @@ -11,10 +10,7 @@ 'Adapter', 'Analyzer', 'NestedDataset', - 'ExecutorFactory', 'Executor', - 'RayExecutor', - 'ExecutorBase', 'Exporter', 'Monitor', 'Tracer', diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index 75ed7676a..fb683bdd4 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -1,7 +1,5 @@ from .base import ExecutorBase from .default_executor import Executor from .factory import ExecutorFactory -from .ray_executor import RayExecutor -__all__ = ['ExecutorBase' - 'ExecutorFactory', 'Executor', 'RayExecutor'] +__all__ = ['ExecutorBase', 'ExecutorFactory', 'Executor'] diff --git a/tests/config/demo_4_dataset_test.yaml b/tests/config/demo_4_dataset_test.yaml deleted file mode 100644 index bffc15f79..000000000 --- a/tests/config/demo_4_dataset_test.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Process config example for Arxiv dataset - -# global parameters -project_name: 'test_demo' -dataset: - type: 'local' - format: 'json' - path: './demos/data/demo-dataset.jsonl' -np: 4 # number of subprocess to process your dataset - -export_path: './outputs/demo/demo-processed.parquet' - -# process schedule -# a list of several process operators with their arguments -process: - - whitespace_normalization_mapper: - - language_id_score_filter: - lang: 'zh' - - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method - lowercase: false # whether to convert text to lower case - ignore_non_character: false - - remove_table_text_mapper: diff --git a/tools/process_data.py b/tools/process_data.py index 8893ec100..71241cf41 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -1,7 +1,7 @@ from loguru import logger from data_juicer.config import init_configs -from data_juicer.core import Executor, RayExecutor +from data_juicer.core import Executor @logger.catch(reraise=True) @@ -10,6 +10,7 @@ def main(): if cfg.executor_type == 'default': executor = Executor(cfg) elif cfg.executor_type == 'ray': + from data_juicer.core.executor.ray_executor import RayExecutor executor = RayExecutor(cfg) executor.run() From 8aae265fa83905d01971304e5214261b0c741fa1 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 13 Feb 2025 12:00:18 -0800 Subject: [PATCH 53/56] bugfix: handle missing configs --- data_juicer/core/data/load_strategy.py | 15 +++-- tests/core/test_dataload_strategy.py | 84 +++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index c582eb2f3..c80422652 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -237,12 +237,19 @@ class DefaultLocalDataLoadStrategy(DefaultDataLoadStrategy): } def load_data(self, **kwargs): - print(f'kwards: {kwargs}') + logger.info(f'kwargs: {kwargs}') + + # Get config values with defaults + text_keys = getattr(self.cfg, 'text_keys', + ['text']) # Default to ['text'] + suffixes = getattr(self.cfg, 'suffixes', None) # Default to None + add_suffix = getattr(self.cfg, 'add_suffix', False) # Default to False + # use proper formatter to load data formatter = load_formatter(dataset_path=self.ds_config['path'], - suffixes=self.cfg.suffixes, - text_keys=self.cfg.text_keys, - add_suffix=self.cfg.add_suffix, + text_keys=text_keys, + suffixes=suffixes, + add_suffix=add_suffix, **kwargs) # TODO more sophiscated localformatter routing return formatter.load_dataset() diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index a9a8f7087..7700f5b5c 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -1,7 +1,9 @@ import unittest from data_juicer.core.data.load_strategy import ( - DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey + DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey, + DefaultLocalDataLoadStrategy ) +from argparse import Namespace class MockStrategy(DataLoadStrategy): def load_data(self): @@ -129,3 +131,83 @@ def test_strategy_key_matches(self): self.assertTrue(pattern_key.matches(match_key)) self.assertFalse(pattern_key.matches(no_match_key)) + + def test_load_strategy_default_config(self): + """Test load strategy with minimal config""" + # Create minimal config + minimal_cfg = Namespace( + path='test/path' + ) + + ds_config = { + 'path': 'test/path' + } + + strategy = DefaultLocalDataLoadStrategy(ds_config, minimal_cfg) + + # Verify defaults are used + assert getattr(strategy.cfg, 'text_keys', ['text']) == ['text'] + assert getattr(strategy.cfg, 'suffixes', None) is None + assert getattr(strategy.cfg, 'add_suffix', False) is False + + def test_load_strategy_full_config(self): + """Test load strategy with full config""" + # Create config with all options + full_cfg = Namespace( + path='test/path', + text_keys=['content', 'title'], + suffixes=['.txt', '.md'], + add_suffix=True + ) + + ds_config = { + 'path': 'test/path' + } + + strategy = DefaultLocalDataLoadStrategy(ds_config, full_cfg) + + # Verify all config values are used + assert strategy.cfg.text_keys == ['content', 'title'] + assert strategy.cfg.suffixes == ['.txt', '.md'] + assert strategy.cfg.add_suffix is True + + def test_load_strategy_partial_config(self): + """Test load strategy with partial config""" + # Create config with some options + partial_cfg = Namespace( + path='test/path', + text_keys=['content'], + # suffixes and add_suffix omitted + ) + + ds_config = { + 'path': 'test/path' + } + + strategy = DefaultLocalDataLoadStrategy(ds_config, partial_cfg) + + # Verify mix of specified and default values + assert strategy.cfg.text_keys == ['content'] + assert getattr(strategy.cfg, 'suffixes', None) is None + assert getattr(strategy.cfg, 'add_suffix', False) is False + + def test_load_strategy_empty_config(self): + """Test load strategy with empty config""" + # Create empty config + empty_cfg = Namespace() + + ds_config = { + 'path': 'test/path' + } + + strategy = DefaultLocalDataLoadStrategy(ds_config, empty_cfg) + + # Verify all defaults are used + assert getattr(strategy.cfg, 'text_keys', ['text']) == ['text'] + assert getattr(strategy.cfg, 'suffixes', None) is None + assert getattr(strategy.cfg, 'add_suffix', False) is False + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 1d65a3a4aa37b3bbdf4bb3b5b4dfebb61bae5bbb Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Thu, 13 Feb 2025 12:55:23 -0800 Subject: [PATCH 54/56] add schema support for datasets --- data_juicer/core/data/dj_dataset.py | 79 ++++++++++++++++++++- data_juicer/core/data/ray_dataset.py | 33 +++++++-- tests/core/data/test_config_ray.yaml | 2 +- tests/core/test_dataset_builder.py | 101 +++++++++++++++++++++++++-- 4 files changed, 203 insertions(+), 12 deletions(-) diff --git a/data_juicer/core/data/dj_dataset.py b/data_juicer/core/data/dj_dataset.py index f9af23f00..7e215c6bd 100644 --- a/data_juicer/core/data/dj_dataset.py +++ b/data_juicer/core/data/dj_dataset.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from functools import wraps from time import time -from typing import Union +from typing import Dict, List, Tuple, Union from datasets import Dataset, DatasetDict, is_caching_enabled from datasets.formatting.formatting import LazyBatch @@ -39,6 +39,17 @@ def process( """process a list of operators on the dataset.""" pass + @abstractmethod + def schema(self) -> Tuple[Dict, List[str]]: + """Get dataset schema and columns. + + Returns: + Tuple containing: + - Dict: Schema information mapping column names to types + - List[str]: List of column names + """ + pass + def wrap_func_with_nested_access(f): """ @@ -165,6 +176,31 @@ def __getitem__(self, key): res = super().__getitem__(key) return nested_obj_factory(res) + def schema(self) -> Tuple[Dict, List[str]]: + """Get dataset schema and columns. + + Returns: + Tuple containing: + - Dict: Schema information mapping column names to types + - List[str]: List of column names + """ + # Get features dictionary from HF dataset + features = self.features + + # Convert features to schema dict + schema = {} + for name, feature in features.items(): + # Map HF feature types to Python types + if hasattr(feature, 'dtype'): + schema[name] = feature.dtype + else: + schema[name] = str(feature) + + # Get column names + columns = self.column_names + + return schema, columns + def process( self, operators, @@ -463,3 +499,44 @@ def add_same_content_to_new_column(sample, """ sample[new_column_name] = initial_value return sample + + +class RayDataset(DJDataset): + """Ray-based dataset implementation.""" + + def schema(self) -> Tuple[Dict, List[str]]: + """Get dataset schema and columns. + + Returns: + Tuple containing: + - Dict: Schema information mapping column names to types + - List[str]: List of column names + """ + # Get PyArrow schema from Ray dataset + arrow_schema = self.data.schema() + + # Convert PyArrow schema to dict + schema = {} + for field in arrow_schema: + schema[field.name] = field.type + + # Get column names + columns = self.data.columns() + + return schema, columns + + def get_schema_string(self) -> str: + """Get a formatted string representation of the schema. + + Returns: + str: Formatted schema string + """ + schema, columns = self.schema() + + # Build formatted string + lines = ['Dataset Schema:'] + lines.append('-' * 40) + for col in columns: + lines.append(f'{col}: {schema[col]}') + + return '\n'.join(lines) diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 88b45290b..63fa32549 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -1,7 +1,9 @@ from __future__ import annotations +import copy +from argparse import Namespace from functools import partial -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import pyarrow from loguru import logger @@ -35,12 +37,31 @@ class RayDataset(DJDataset): def __init__(self, dataset: rd.Dataset, dataset_path: str = None, - cfg=None) -> None: + cfg: Optional[Namespace] = None) -> None: self.data = dataset - # self.data = preprocess_dataset(dataset, dataset_path, cfg) - self.num_proc = None - if cfg: - self.num_proc = cfg.np + self.num_proc = getattr(cfg, 'np', getattr(cfg, 'num_proc', + None)) if cfg else None + + def schema(self) -> Tuple[Dict, List[str]]: + """Get dataset schema and columns. + + Returns: + Tuple containing: + - Dict: Schema information mapping column names to types + - List[str]: List of column names + """ + # Get schema from Ray dataset + ray_schema = self.data.schema() + + # Convert PyArrow schema to dict + schema = {} + for n, t in zip(ray_schema.names, ray_schema.types): + schema[n] = t + + # Get column names + columns = copy.deepcopy(ray_schema.names) + + return schema, columns def process(self, operators, diff --git a/tests/core/data/test_config_ray.yaml b/tests/core/data/test_config_ray.yaml index ff3220c15..c1a64c43e 100644 --- a/tests/core/data/test_config_ray.yaml +++ b/tests/core/data/test_config_ray.yaml @@ -4,7 +4,7 @@ project_name: 'ray-demo-new-config' dataset: configs: - type: local - path: ./demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file + path: ../../demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file weight: 1.0 export_path: './outputs/demo/demo-processed' diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index bee4b3cd3..31a633d9f 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -9,12 +9,12 @@ parse_cli_datapath, DatasetBuilder) from data_juicer.core.data.config_validator import ConfigValidationError -from data_juicer.utils.unittest_utils import (DataJuicerTestCaseBase, - SKIPPED_TESTS) +from data_juicer.utils.unittest_utils import (DataJuicerTestCaseBase) from data_juicer.core.data.load_strategy import RayLocalJsonDataLoadStrategy +from data_juicer.core.data import RayDataset -WORK_DIR = os.path.dirname(os.path.realpath(__file__)) +WORK_DIR = os.path.abspath(os.path.dirname(os.path.realpath(__file__))) class DatasetBuilderTest(DataJuicerTestCaseBase): @@ -489,7 +489,7 @@ def test_builder_ray_config(self): self.assertEqual(cfg.dataset, { 'configs': [{ 'type': 'local', - 'path': './demos/process_on_ray/data/demo-dataset.jsonl', + 'path': '../../demos/process_on_ray/data/demo-dataset.jsonl', 'weight': 1.0 }] }) @@ -499,5 +499,98 @@ def test_builder_ray_config(self): self.assertEqual(len(builder.load_strategies), 1) self.assertIsInstance(builder.load_strategies[0], RayLocalJsonDataLoadStrategy) + # Load dataset and verify schema + dataset = builder.load_dataset() + schema, columns = dataset.schema() + + # Verify expected columns exist + self.assertIn('text', columns) + + # Verify schema types + import pyarrow as pa + self.assertTrue(pa.types.is_string(schema['text'])) + + + ### schema related tests + def test_builder_schema_single_dataset(self): + """Test schema for single dataset configuration""" + # Setup single dataset config + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'local', + 'path': os.path.join(WORK_DIR, 'data/sample.json') + } + ] + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + dataset = builder.load_dataset() + + # Get schema + schema, columns = dataset.schema() + + # Verify expected columns exist + self.assertIn('text', columns) + + # Verify schema types + if isinstance(dataset, RayDataset): + import pyarrow as pa + self.assertTrue(pa.types.is_string(schema['text'])) + else: # NestedDataset + self.assertEqual(schema['text'], 'string') + + def test_builder_schema_multiple_datasets(self): + """Test schema for multiple dataset configurations""" + # Setup multiple dataset config + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'local', + 'path': os.path.join(WORK_DIR, 'data/sample.json') + }, + { + 'type': 'local', + 'path': os.path.join(WORK_DIR, 'data/sample.txt') + } + ] + } + + builder = DatasetBuilder(self.base_cfg, self.executor_type) + dataset = builder.load_dataset() + + schema, columns = dataset.schema() + + # Verify columns from both datasets + self.assertIn('text', columns) + + # Verify schema consistency + if isinstance(dataset, RayDataset): + import pyarrow as pa + self.assertTrue(pa.types.is_string(schema['text'])) + else: # NestedDataset + self.assertEqual(schema['text'], 'string') + + def test_builder_schema_validation(self): + """Test schema validation during dataset building""" + # Test with invalid schema + self.base_cfg.dataset = { + 'configs': [ + { + 'type': 'local', + 'path': os.path.join(WORK_DIR, 'data/invalid_schema.json') + } + ] + } + + with self.assertRaises(Exception) as context: + builder = DatasetBuilder(self.base_cfg, self.executor_type) + dataset = builder.load_dataset() + schema, columns = dataset.schema() + + # Verify error message + self.assertIn('schema', str(context.exception).lower()) + + if __name__ == '__main__': unittest.main() From 96a49975397a189af5f02516b5325cbe296ed289 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 14 Feb 2025 11:52:10 -0800 Subject: [PATCH 55/56] bugfix: handle relative path problem in tests --- .gitignore | 1 + tests/core/data/test_config_ray.yaml | 2 +- tests/core/test_dataset_builder.py | 4 ++-- tests/tools/test_process_data.py | 22 +++++++++++++++------- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 987d33bc3..f25ab4307 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ venv/ # dup files created by tests tests/ops/data/*dup* +tests/tools/tmp_*/ diff --git a/tests/core/data/test_config_ray.yaml b/tests/core/data/test_config_ray.yaml index c1a64c43e..19e251fdb 100644 --- a/tests/core/data/test_config_ray.yaml +++ b/tests/core/data/test_config_ray.yaml @@ -4,7 +4,7 @@ project_name: 'ray-demo-new-config' dataset: configs: - type: local - path: ../../demos/process_on_ray/data/demo-dataset.jsonl # path to your dataset directory or file + path: ./data/sample.json # path to your dataset directory or file weight: 1.0 export_path: './outputs/demo/demo-processed' diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py index 31a633d9f..6f810d035 100644 --- a/tests/core/test_dataset_builder.py +++ b/tests/core/test_dataset_builder.py @@ -474,7 +474,7 @@ def test_invalid_max_sample_num(self): def test_builder_ray_config(self): """Test loading Ray configuration from YAML""" - test_config_file = os.path.join(WORK_DIR, 'data/test_config_ray.yaml') + test_config_file = os.path.join(WORK_DIR, 'data', 'test_config_ray.yaml') out = StringIO() with redirect_stdout(out): cfg = init_configs(args=f'--config {test_config_file}'.split()) @@ -489,7 +489,7 @@ def test_builder_ray_config(self): self.assertEqual(cfg.dataset, { 'configs': [{ 'type': 'local', - 'path': '../../demos/process_on_ray/data/demo-dataset.jsonl', + 'path': './data/sample.json', 'weight': 1.0 }] }) diff --git a/tests/tools/test_process_data.py b/tests/tools/test_process_data.py index b0c5a3063..c45f01570 100644 --- a/tests/tools/test_process_data.py +++ b/tests/tools/test_process_data.py @@ -62,8 +62,10 @@ def _test_status_code(self, yaml_file, output_path, text_keys): with open(yaml_file, 'w') as file: yaml.dump(yaml_config, file) + script_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), + "tools", "process_data.py") status_code = subprocess.call( - f'python tools/process_data.py --config {yaml_file}', shell=True) + f'python {script_path} --config {yaml_file}', shell=True) return status_code @@ -95,6 +97,8 @@ def setUp(self): cur_dir = osp.dirname(osp.abspath(__file__)) self.tmp_dir = osp.join(cur_dir, f'tmp_{uuid.uuid4().hex}') + self.script_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), + "tools", "process_data.py") os.makedirs(self.tmp_dir, exist_ok=True) def tearDown(self): @@ -113,7 +117,7 @@ def test_ray_image(self): text_keys = 'text' data_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), - 'demos', 'data', 'demo-dataset-images.jsonl') + 'demos', 'data', 'demo-dataset-images.jsonl') yaml_config = { 'dataset_path': data_path, 'executor_type': 'ray', @@ -141,7 +145,8 @@ def test_ray_image(self): with open(tmp_yaml_file, 'w') as file: yaml.dump(yaml_config, file) - run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + print(f"Is the config file present? {os.path.exists(tmp_yaml_file)}") + run_in_subprocess(f'python {self.script_path} --config {tmp_yaml_file}') self.assertTrue(osp.exists(tmp_out_path)) @@ -184,7 +189,9 @@ def test_ray_precise_dedup(self): with open(tmp_yaml_file, 'w') as file: yaml.dump(yaml_config, file) - run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + script_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), + "tools", "process_data.py") + run_in_subprocess(f'python {script_path} --config {tmp_yaml_file}') self.assertTrue(osp.exists(tmp_out_path)) @@ -227,7 +234,7 @@ def test_ray_minhash_dedup(self): with open(tmp_yaml_file, 'w') as file: yaml.dump(yaml_config, file) - run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + run_in_subprocess(f'python {self.script_path} --config {tmp_yaml_file}') self.assertTrue(osp.exists(tmp_out_path)) @@ -282,7 +289,7 @@ def test_ray_compute_stats_single_filter(self): with open(tmp_yaml_file, 'w') as file: yaml.dump(yaml_config, file) - run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + run_in_subprocess(f'python {self.script_path} --config {tmp_yaml_file}') self.assertTrue(osp.exists(tmp_out_path)) @@ -344,7 +351,8 @@ def test_ray_compute_stats_batched_filter(self): with open(tmp_yaml_file, 'w') as file: yaml.dump(yaml_config, file) - run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + + run_in_subprocess(f'python {self.script_path} --config {tmp_yaml_file}') self.assertTrue(osp.exists(tmp_out_path)) From 2f49eece52d935463dbefa39cef66bf37b3b68c9 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Fri, 14 Feb 2025 12:25:54 -0800 Subject: [PATCH 56/56] fix test cases --- tests/core/test_data_validator.py | 57 ++++++++++++++++++++-------- tests/core/test_dataload_strategy.py | 4 +- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/tests/core/test_data_validator.py b/tests/core/test_data_validator.py index dccbd8e96..1faba704b 100644 --- a/tests/core/test_data_validator.py +++ b/tests/core/test_data_validator.py @@ -1,32 +1,43 @@ from unittest import TestCase, main import datasets -import ray -import ray.data -import pandas as pd - -from data_juicer.core.data import NestedDataset, RayDataset +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG +from data_juicer.core.data import NestedDataset from data_juicer.core.data.data_validator import (DataValidationError, RequiredFieldsValidator ) # Test RequiredFieldsValidator -class RequiredFieldsValidatorTest(TestCase): +class RequiredFieldsValidatorTest(DataJuicerTestCaseBase): def setUp(self): # Create sample DataFrame - self.df = pd.DataFrame({ - 'text': ['Hello', 'World', None, 'Test'], - 'metadata': [{'lang': 'en'}, {'lang': 'es'}, {'lang': 'fr'}, None], - 'score': [1.0, 2.0, 3.0, 4.0] - }) + self.data = [ + { + 'text': 'Hello', + 'metadata': {'lang': 'en'}, + 'score': 1.0 + }, + { + 'text': 'World', + 'metadata': {'lang': 'es'}, + 'score': 2.0 + }, + { + 'text': None, + 'metadata': {'lang': 'fr'}, + 'score': 3.0 + }, + { + 'text': 'Test', + 'metadata': None, + 'score': 4.0 + } + ] # Create dataset - self.dataset = NestedDataset(datasets.Dataset.from_pandas(self.df)) - - # Create ray dataset - self.ray_dataset = RayDataset(ray.data.from_pandas(self.df)) - + self.dataset = NestedDataset(datasets.Dataset.from_list(self.data)) + def test_basic_validation(self): """Test basic field validation""" @@ -67,7 +78,14 @@ def test_type_validation(self): validator.validate(self.dataset) self.assertIn("incorrect type", str(exc.exception).lower()) + @TEST_TAG('ray') def test_ray_dataset_support(self): + import ray.data + from data_juicer.core.data import RayDataset + + # Create ray dataset + self.ray_dataset = RayDataset(ray.data.from_items(self.data)) + """Test validation with RayDataset""" config = { 'required_fields': ['text', 'metadata'], @@ -103,7 +121,14 @@ def test_empty_required_fields(self): # Should pass as no fields are required validator.validate(self.dataset) + @TEST_TAG('ray') def test_multiple_dataset_types(self): + import ray.data + from data_juicer.core.data import RayDataset + + # Create ray dataset + self.ray_dataset = RayDataset(ray.data.from_items(self.data)) + """Test validation works with different dataset types""" datasets_to_test = [ ('nested', self.dataset), diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py index 7700f5b5c..2d490286a 100644 --- a/tests/core/test_dataload_strategy.py +++ b/tests/core/test_dataload_strategy.py @@ -4,12 +4,12 @@ DefaultLocalDataLoadStrategy ) from argparse import Namespace - +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class MockStrategy(DataLoadStrategy): def load_data(self): pass -class DataLoadStrategyRegistryTest(unittest.TestCase): +class DataLoadStrategyRegistryTest(DataJuicerTestCaseBase): def setUp(self): # Clear existing strategies before each test DataLoadStrategyRegistry._strategies = {}