Skip to content

Commit

Permalink
Use rng_seed param when creating custom dataset sampler (#3592)
Browse files Browse the repository at this point in the history
Use the `rng_seed` configuration parameter in the
`PerDatasetSampler.build_sampler_from_config()` factory class
method. Until now, the fixed default value of 0 was always used as the seed
for dataset sampling (which I think was a bug).
  • Loading branch information
andreaskoepf authored Jul 21, 2023
1 parent 11a1842 commit 6336f31
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
3 changes: 2 additions & 1 deletion model/model_training/trainer_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import random
from argparse import Namespace
from typing import Sequence

import numpy as np
import torch
Expand All @@ -21,7 +22,7 @@
from utils.utils_rl import prepare_tensor


def argument_parsing(notebook=False, notebook_args=None, **kwargs):
def argument_parsing(notebook: bool = False, notebook_args: Sequence[str] | None = None, **kwargs):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+", required=True)
parser.add_argument("--local_rank", type=int, default=-1)
Expand Down
4 changes: 2 additions & 2 deletions model/model_training/trainer_rm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import os
from typing import Callable, Literal, Optional, Union
from typing import Callable, Literal, Optional, Sequence, Union

import datasets
import torch
Expand Down Expand Up @@ -128,7 +128,7 @@ def get_train_dataloader(self):
return dataloader


def argument_parsing(notebook=False, notebook_args=None):
def argument_parsing(notebook: bool = False, notebook_args: Sequence[str] | None = None):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+", required=True)
parser.add_argument("--local_rank", type=int, default=-1)
Expand Down
4 changes: 2 additions & 2 deletions model/model_training/trainer_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import datasets
import torch
Expand Down Expand Up @@ -166,7 +166,7 @@ def get_train_dataloader(self):
return dataloader


def argument_parsing(notebook=False, notebook_args=None):
def argument_parsing(notebook: bool = False, notebook_args: Sequence[str] | None = None):
parser = argparse.ArgumentParser()
parser.add_argument(
"--configs",
Expand Down
8 changes: 5 additions & 3 deletions model/model_training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
self.shuffle = shuffle
self.rank = rank
self.world_size = world_size
self.epoch = 0

if world_size == 1:
self.rank = 0
Expand All @@ -89,7 +90,7 @@ def __init__(
self.seed = seed
self.samples_length = samples_length

def set_epoch(self, epoch) -> None:
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch

def __len__(self) -> int:
Expand Down Expand Up @@ -126,11 +127,12 @@ def __iter__(self):
return iter(epoch_idx)

@classmethod
def build_sampler_from_config(cls, training_conf, datasets: List[Dataset], verbose: bool = False, *args, **kwargs):
def build_sampler_from_config(cls, training_conf, datasets: List[Dataset], verbose: bool = False, **kwargs):
dataset_sizes = [len(x) for x in datasets]
fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes, verbose)
dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)]
return cls(dataset_sizes, dataset_size_per_epoch, *args, **kwargs)
seed = training_conf.rng_seed
return cls(dataset_sizes=dataset_sizes, dataset_size_per_epoch=dataset_size_per_epoch, seed=seed, **kwargs)


def get_dataset_fractions(conf, dataset_sizes: List[int], verbose: bool = False):
Expand Down

0 comments on commit 6336f31

Please sign in to comment.