diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 0a0864b8ef..ec7b4334a5 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -23,6 +23,7 @@ buffer: experience_buffer: name: dpo_buffer storage_type: file + enable_progress_bar: True path: /PATH/TO/DATASET/ format: prompt_type: plaintext # plaintext/messages/chatpair diff --git a/tests/tools.py b/tests/tools.py index a9b2ca8349..60df1122f2 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -48,6 +48,7 @@ def get_unittest_dataset_config( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"), split=split, + enable_progress_bar=False, format=FormatConfig( prompt_key="question", response_key="answer", diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index bc49b871d3..fa5904203d 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -5,7 +5,6 @@ import datasets import transformers from datasets import Dataset, load_dataset -from ray.experimental.tqdm_ray import tqdm from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm from trinity.buffer.buffer_reader import BufferReader @@ -19,6 +18,17 @@ FILE_READERS = Registry("file_readers") +class DummyProgressBar: + def __init__(self): + pass + + def update(self, num: int): + pass + + def close(self): + pass + + class _HFBatchReader: def __init__( self, @@ -29,6 +39,7 @@ def __init__( offset: int = 0, drop_last: bool = True, total_steps: Optional[int] = None, + enable_progress_bar: Optional[bool] = True, ): self.dataset = dataset self.dataset_size = len(dataset) @@ -47,10 +58,17 @@ def __init__( self.total_samples = default_batch_size * total_steps else: self.total_samples = self.dataset_size * total_epochs - self.progress_bar = tqdm( - total=self.total_samples, - desc=f"Dataset [{self.name}] Progressing", - ) + + if enable_progress_bar: + from ray.experimental.tqdm_ray import tqdm + + self.progress_bar = tqdm( + total=self.total_samples, + desc=f"Dataset [{self.name}] Progressing", + ) + else: + self.progress_bar = DummyProgressBar() + self.progress_bar.update(self.current_offset) def read_batch(self, batch_size: int) -> List: @@ -99,6 +117,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): total_epochs=meta.total_epochs, drop_last=True, total_steps=meta.total_steps, + enable_progress_bar=meta.enable_progress_bar, ) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -180,6 +199,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): total_epochs=meta.total_epochs, drop_last=True, total_steps=meta.total_steps, + enable_progress_bar=meta.enable_progress_bar, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -259,6 +279,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): offset=self.meta.index, drop_last=self.meta.task_type == TaskType.EXPLORE, total_steps=meta.total_steps, + enable_progress_bar=meta.enable_progress_bar, ) self.prompt_key = meta.format.prompt_key self.response_key = meta.format.response_key diff --git a/trinity/common/config.py b/trinity/common/config.py index 1e0bcc5e9d..506fe2c147 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -101,6 +101,9 @@ class StorageConfig: workflow_args: dict = field(default_factory=dict) reward_fn_args: dict = field(default_factory=dict) + # enable progress bar (tqdm) for _HFBatchReader + enable_progress_bar: Optional[bool] = True + # get storage from existing experiment ray_namespace: Optional[str] = None