Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/dpo_humanlike/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ buffer:
experience_buffer:
name: dpo_buffer
storage_type: file
enable_progress_bar: True
path: /PATH/TO/DATASET/
format:
prompt_type: plaintext # plaintext/messages/chatpair
Expand Down
1 change: 1 addition & 0 deletions tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get_unittest_dataset_config(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
split=split,
enable_progress_bar=False,
format=FormatConfig(
prompt_key="question",
response_key="answer",
Expand Down
31 changes: 26 additions & 5 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import datasets
import transformers
from datasets import Dataset, load_dataset
from ray.experimental.tqdm_ray import tqdm

from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
Expand All @@ -19,6 +18,17 @@
FILE_READERS = Registry("file_readers")


class DummyProgressBar:
def __init__(self):
pass

def update(self, num: int):
pass

def close(self):
pass


class _HFBatchReader:
def __init__(
self,
Expand All @@ -29,6 +39,7 @@ def __init__(
offset: int = 0,
drop_last: bool = True,
total_steps: Optional[int] = None,
enable_progress_bar: Optional[bool] = True,
):
self.dataset = dataset
self.dataset_size = len(dataset)
Expand All @@ -47,10 +58,17 @@ def __init__(
self.total_samples = default_batch_size * total_steps
else:
self.total_samples = self.dataset_size * total_epochs
self.progress_bar = tqdm(
total=self.total_samples,
desc=f"Dataset [{self.name}] Progressing",
)

if enable_progress_bar:
from ray.experimental.tqdm_ray import tqdm

self.progress_bar = tqdm(
total=self.total_samples,
desc=f"Dataset [{self.name}] Progressing",
)
else:
self.progress_bar = DummyProgressBar()

self.progress_bar.update(self.current_offset)

def read_batch(self, batch_size: int) -> List:
Expand Down Expand Up @@ -99,6 +117,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
total_epochs=meta.total_epochs,
drop_last=True,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

Expand Down Expand Up @@ -180,6 +199,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
total_epochs=meta.total_epochs,
drop_last=True,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

Expand Down Expand Up @@ -259,6 +279,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
offset=self.meta.index,
drop_last=self.meta.task_type == TaskType.EXPLORE,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
)
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
Expand Down
3 changes: 3 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class StorageConfig:
workflow_args: dict = field(default_factory=dict)
reward_fn_args: dict = field(default_factory=dict)

# enable progress bar (tqdm) for _HFBatchReader
enable_progress_bar: Optional[bool] = True

# get storage from existing experiment
ray_namespace: Optional[str] = None

Expand Down