diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 1058222d33..b7a50875b7 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -182,6 +182,7 @@ buffer:
- `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
- `total_epochs`: Total number of training epochs.
+- `total_steps`: Optional. The total number of training steps. If specified, `total_epochs` will be ignored.
### Explorer Input
diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index 2882dd8e0f..490117da1d 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -57,7 +57,7 @@ def test_file_buffer(self):
self.assertEqual(loaded_data, data)
self.assertRaises(StopIteration, reader.read)
- def test_file_reader(self):
+ def test_file_reader(self): # noqa: C901
"""Test file reader."""
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
@@ -81,7 +81,21 @@ def test_file_reader(self):
break
self.assertEqual(len(tasks), 16 * 2 - 4)
- # test offset > dataset_len
+ # test total steps and offset
+ self.config.buffer.explorer_input.taskset.total_steps = 5
+ self.config.buffer.explorer_input.taskset.index = 8
+ reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+ tasks = []
+ while True:
+ try:
+ tasks.extend(reader.read())
+ print(f"read from buffer, current len {len(tasks)}.")
+ except StopIteration:
+ break
+ self.assertEqual(len(tasks), 20 - 8)
+
+ # test offset > dataset_len with total_epoch
+ self.config.buffer.explorer_input.taskset.total_steps = None
self.config.buffer.explorer_input.taskset.total_epochs = 3
self.config.buffer.explorer_input.taskset.index = 20
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
@@ -93,6 +107,18 @@ def test_file_reader(self):
break
self.assertEqual(len(tasks), 16 * 3 - 20)
+ # test offset > dataset_len with total_steps
+ self.config.buffer.explorer_input.taskset.total_steps = 10
+ self.config.buffer.explorer_input.taskset.index = 24
+ reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+ tasks = []
+ while True:
+ try:
+ tasks.extend(reader.read())
+ except StopIteration:
+ break
+ self.assertEqual(len(tasks), 40 - 24)
+
def test_file_writer(self):
writer = get_buffer_writer(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl
index a8d6972103..8aa1f6d3d4 100644
--- a/tests/template/data/sft_for_gsm8k/sft.jsonl
+++ b/tests/template/data/sft_for_gsm8k/sft.jsonl
@@ -6,27 +6,3 @@
{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
-{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
-{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
-{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
-{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 27aad9d8cc..4c87732926 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -218,10 +218,11 @@ def test_trainer(self):
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.synchronizer.sync_interval = 1
self.config.trainer.save_interval = 8
- self.config.buffer.trainer_input.sft_warmup_steps = 2
+ # sft data is only enough for 2 steps
self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
"sft_for_gsm8k"
)
+ self.config.buffer.trainer_input.sft_warmup_steps = 3
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
@@ -229,16 +230,16 @@ def test_trainer(self):
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
- self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 6)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 7)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
sft_metrics = parser.metric_list("actor/sft")
- self.assertEqual(parser.metric_max_step(sft_metrics[0]), 2) # SFT
- self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 6) # RFT
+ self.assertEqual(parser.metric_max_step(sft_metrics[0]), 3) # SFT
+ self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 7) # RFT
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
- self.assertEqual(parser.metric_min_step(response_metrics[0]), 3)
- self.assertEqual(parser.metric_max_step(response_metrics[0]), 6)
+ self.assertEqual(parser.metric_min_step(response_metrics[0]), 4)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 7)
# test save checkpoint when sft finish
self.assertEqual(
get_checkpoint_dir_with_step_num(
@@ -251,7 +252,7 @@ def test_trainer(self):
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type="verl",
)
- self.assertEqual(step_num, 6)
+ self.assertEqual(step_num, 7)
self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0)
def tearDown(self):
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 2304db0ad8..91ca4bc030 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -24,38 +24,37 @@ def __init__(
self,
dataset: Dataset,
name: str,
- max_epoch: int = 1,
+ default_batch_size: int,
+ total_epochs: int = 1,
offset: int = 0,
drop_last: bool = True,
+ total_steps: Optional[int] = None,
):
self.dataset = dataset
self.dataset_size = len(dataset)
self.name = name
self.current_batch_size = None
- self.max_epoch = max_epoch
self.drop_last = drop_last
- if offset >= self.dataset_size:
- self.current_epoch = offset // self.dataset_size
- self.current_offset = offset % self.dataset_size
- else:
- self.current_epoch = 0
- self.current_offset = offset
+
+ self.current_offset = offset
self.iter = iter(self.dataset)
- for _ in range(self.current_offset):
+ for _ in range(self.current_offset % self.dataset_size):
next(self.iter)
- # Initialize tqdm progress bar
- self.total_steps = self.dataset_size * self.max_epoch
+ # convert epochs/steps to sample number
+ if total_steps:
+ self.total_samples = default_batch_size * total_steps
+ else:
+ self.total_samples = self.dataset_size * total_epochs
self.progress_bar = tqdm(
- total=self.total_steps,
+ total=self.total_samples,
desc=f"Dataset [{self.name}] Progressing",
)
- initial = self.current_epoch * self.dataset_size + self.current_offset
- self.progress_bar.update(initial)
+ self.progress_bar.update(self.current_offset)
def read_batch(self, batch_size: int) -> List:
- if self.current_epoch >= self.max_epoch:
+ if self.current_offset >= self.total_samples:
self.progress_bar.close()
raise StopIteration
batch = []
@@ -66,11 +65,10 @@ def read_batch(self, batch_size: int) -> List:
batch.append(item)
self.current_offset += 1
except StopIteration:
- self.current_epoch += 1
- self.current_offset = 0
-
- if self.current_epoch >= self.max_epoch:
+ if self.current_offset >= self.total_samples:
+ # No more data to read
if not self.drop_last and len(batch) > 0:
+ # return last batch
self.progress_bar.update(len(batch))
return batch
else:
@@ -97,9 +95,11 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
name=meta.name,
- max_epoch=meta.total_epochs,
+ default_batch_size=self.read_batch_size,
+ total_epochs=meta.total_epochs,
drop_last=True,
- ) # TODO: support resume
+ total_steps=meta.total_steps,
+ )
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
def read(
@@ -176,8 +176,10 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
name=meta.name,
- max_epoch=meta.total_epochs,
+ default_batch_size=self.read_batch_size,
+ total_epochs=meta.total_epochs,
drop_last=True,
+ total_steps=meta.total_steps,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -248,14 +250,16 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
# disable datasets caching to avoid reuse old-version dataset
self.epoch = 0
datasets.disable_caching()
+ self.read_batch_size = config.batch_size
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
name=meta.name,
- max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
+ default_batch_size=self.read_batch_size,
+ total_epochs=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
offset=self.meta.index,
drop_last=self.meta.task_type == TaskType.EXPLORE,
+ total_steps=meta.total_steps,
)
- self.read_batch_size = config.batch_size
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
self.workflow_key = meta.format.workflow_key
@@ -297,9 +301,6 @@ def read(
tasks.append(task)
return tasks
- def reset(self):
- self.dataset.reset()
-
@FILE_READERS.register_module("raw")
class RawDataReader(BufferReader):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 100d74d9da..29973d8342 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -104,6 +104,9 @@ class StorageConfig:
# ! DO NOT SET, automatically set from buffer.total_epochs
total_epochs: int = 1 # automatically set
+ # ! DO NOT SET, automatically set from buffer.total_steps
+ total_steps: Optional[int] = None # automatically set
+
# ! DO NOT SET, automatically set corresponding to train/eval
task_type: TaskType = TaskType.EXPLORE
@@ -275,6 +278,7 @@ class BufferConfig:
batch_size: int = 1
total_epochs: int = 1
+ total_steps: Optional[int] = None
# for explorer
explorer_input: ExplorerInput = field(default_factory=ExplorerInput)
@@ -438,6 +442,7 @@ def _check_buffer(self) -> None: # noqa: C901
)
self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs
+ self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps
if self.buffer.explorer_input.taskset.default_workflow_type is None:
self.buffer.explorer_input.taskset.default_workflow_type = (
self.buffer.explorer_input.default_workflow_type
@@ -520,6 +525,9 @@ def _check_buffer(self) -> None: # noqa: C901
)
if self.buffer.trainer_input.sft_warmup_dataset is not None:
self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO
+ self.buffer.trainer_input.sft_warmup_dataset.total_steps = (
+ self.buffer.trainer_input.sft_warmup_steps
+ )
if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None:
self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace