diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 1058222d33..b7a50875b7 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -182,6 +182,7 @@ buffer: - `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*. - `total_epochs`: Total number of training epochs. +- `total_steps`: Optional. The total number of training steps. If specified, `total_epochs` will be ignored. ### Explorer Input diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 2882dd8e0f..490117da1d 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -57,7 +57,7 @@ def test_file_buffer(self): self.assertEqual(loaded_data, data) self.assertRaises(StopIteration, reader.read) - def test_file_reader(self): + def test_file_reader(self): # noqa: C901 """Test file reader.""" reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) @@ -81,7 +81,21 @@ def test_file_reader(self): break self.assertEqual(len(tasks), 16 * 2 - 4) - # test offset > dataset_len + # test total steps and offset + self.config.buffer.explorer_input.taskset.total_steps = 5 + self.config.buffer.explorer_input.taskset.index = 8 + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) + tasks = [] + while True: + try: + tasks.extend(reader.read()) + print(f"read from buffer, current len {len(tasks)}.") + except StopIteration: + break + self.assertEqual(len(tasks), 20 - 8) + + # test offset > dataset_len with total_epoch + self.config.buffer.explorer_input.taskset.total_steps = None self.config.buffer.explorer_input.taskset.total_epochs = 3 self.config.buffer.explorer_input.taskset.index = 20 reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) @@ -93,6 +107,18 @@ def test_file_reader(self): break self.assertEqual(len(tasks), 16 * 3 - 20) + # test offset > dataset_len with total_steps + self.config.buffer.explorer_input.taskset.total_steps = 10 + self.config.buffer.explorer_input.taskset.index = 24 + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) + tasks = [] + while True: + try: + tasks.extend(reader.read()) + except StopIteration: + break + self.assertEqual(len(tasks), 40 - 24) + def test_file_writer(self): writer = get_buffer_writer( self.config.buffer.trainer_input.experience_buffer, self.config.buffer diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl index a8d6972103..8aa1f6d3d4 100644 --- a/tests/template/data/sft_for_gsm8k/sft.jsonl +++ b/tests/template/data/sft_for_gsm8k/sft.jsonl @@ -6,27 +6,3 @@ {"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} {"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} {"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} -{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} -{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} -{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} -{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 27aad9d8cc..4c87732926 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -218,10 +218,11 @@ def test_trainer(self): self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.synchronizer.sync_interval = 1 self.config.trainer.save_interval = 8 - self.config.buffer.trainer_input.sft_warmup_steps = 2 + # sft data is only enough for 2 steps self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config( "sft_for_gsm8k" ) + self.config.buffer.trainer_input.sft_warmup_steps = 3 self.config.check_and_update() self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 @@ -229,16 +230,16 @@ def test_trainer(self): parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 6) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 7) actor_metrics = parser.metric_list("actor") self.assertTrue(len(actor_metrics) > 0) sft_metrics = parser.metric_list("actor/sft") - self.assertEqual(parser.metric_max_step(sft_metrics[0]), 2) # SFT - self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 6) # RFT + self.assertEqual(parser.metric_max_step(sft_metrics[0]), 3) # SFT + self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 7) # RFT response_metrics = parser.metric_list("response_length") self.assertTrue(len(response_metrics) > 0) - self.assertEqual(parser.metric_min_step(response_metrics[0]), 3) - self.assertEqual(parser.metric_max_step(response_metrics[0]), 6) + self.assertEqual(parser.metric_min_step(response_metrics[0]), 4) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 7) # test save checkpoint when sft finish self.assertEqual( get_checkpoint_dir_with_step_num( @@ -251,7 +252,7 @@ def test_trainer(self): checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type="verl", ) - self.assertEqual(step_num, 6) + self.assertEqual(step_num, 7) self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0) def tearDown(self): diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 2304db0ad8..91ca4bc030 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -24,38 +24,37 @@ def __init__( self, dataset: Dataset, name: str, - max_epoch: int = 1, + default_batch_size: int, + total_epochs: int = 1, offset: int = 0, drop_last: bool = True, + total_steps: Optional[int] = None, ): self.dataset = dataset self.dataset_size = len(dataset) self.name = name self.current_batch_size = None - self.max_epoch = max_epoch self.drop_last = drop_last - if offset >= self.dataset_size: - self.current_epoch = offset // self.dataset_size - self.current_offset = offset % self.dataset_size - else: - self.current_epoch = 0 - self.current_offset = offset + + self.current_offset = offset self.iter = iter(self.dataset) - for _ in range(self.current_offset): + for _ in range(self.current_offset % self.dataset_size): next(self.iter) - # Initialize tqdm progress bar - self.total_steps = self.dataset_size * self.max_epoch + # convert epochs/steps to sample number + if total_steps: + self.total_samples = default_batch_size * total_steps + else: + self.total_samples = self.dataset_size * total_epochs self.progress_bar = tqdm( - total=self.total_steps, + total=self.total_samples, desc=f"Dataset [{self.name}] Progressing", ) - initial = self.current_epoch * self.dataset_size + self.current_offset - self.progress_bar.update(initial) + self.progress_bar.update(self.current_offset) def read_batch(self, batch_size: int) -> List: - if self.current_epoch >= self.max_epoch: + if self.current_offset >= self.total_samples: self.progress_bar.close() raise StopIteration batch = [] @@ -66,11 +65,10 @@ def read_batch(self, batch_size: int) -> List: batch.append(item) self.current_offset += 1 except StopIteration: - self.current_epoch += 1 - self.current_offset = 0 - - if self.current_epoch >= self.max_epoch: + if self.current_offset >= self.total_samples: + # No more data to read if not self.drop_last and len(batch) > 0: + # return last batch self.progress_bar.update(len(batch)) return batch else: @@ -97,9 +95,11 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.dataset = _HFBatchReader( load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), name=meta.name, - max_epoch=meta.total_epochs, + default_batch_size=self.read_batch_size, + total_epochs=meta.total_epochs, drop_last=True, - ) # TODO: support resume + total_steps=meta.total_steps, + ) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) def read( @@ -176,8 +176,10 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.dataset = _HFBatchReader( load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), name=meta.name, - max_epoch=meta.total_epochs, + default_batch_size=self.read_batch_size, + total_epochs=meta.total_epochs, drop_last=True, + total_steps=meta.total_steps, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -248,14 +250,16 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): # disable datasets caching to avoid reuse old-version dataset self.epoch = 0 datasets.disable_caching() + self.read_batch_size = config.batch_size self.dataset = _HFBatchReader( load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), name=meta.name, - max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1, + default_batch_size=self.read_batch_size, + total_epochs=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1, offset=self.meta.index, drop_last=self.meta.task_type == TaskType.EXPLORE, + total_steps=meta.total_steps, ) - self.read_batch_size = config.batch_size self.prompt_key = meta.format.prompt_key self.response_key = meta.format.response_key self.workflow_key = meta.format.workflow_key @@ -297,9 +301,6 @@ def read( tasks.append(task) return tasks - def reset(self): - self.dataset.reset() - @FILE_READERS.register_module("raw") class RawDataReader(BufferReader): diff --git a/trinity/common/config.py b/trinity/common/config.py index 100d74d9da..29973d8342 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -104,6 +104,9 @@ class StorageConfig: # ! DO NOT SET, automatically set from buffer.total_epochs total_epochs: int = 1 # automatically set + # ! DO NOT SET, automatically set from buffer.total_steps + total_steps: Optional[int] = None # automatically set + # ! DO NOT SET, automatically set corresponding to train/eval task_type: TaskType = TaskType.EXPLORE @@ -275,6 +278,7 @@ class BufferConfig: batch_size: int = 1 total_epochs: int = 1 + total_steps: Optional[int] = None # for explorer explorer_input: ExplorerInput = field(default_factory=ExplorerInput) @@ -438,6 +442,7 @@ def _check_buffer(self) -> None: # noqa: C901 ) self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs + self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps if self.buffer.explorer_input.taskset.default_workflow_type is None: self.buffer.explorer_input.taskset.default_workflow_type = ( self.buffer.explorer_input.default_workflow_type @@ -520,6 +525,9 @@ def _check_buffer(self) -> None: # noqa: C901 ) if self.buffer.trainer_input.sft_warmup_dataset is not None: self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO + self.buffer.trainer_input.sft_warmup_dataset.total_steps = ( + self.buffer.trainer_input.sft_warmup_steps + ) if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None: self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace