Skip to content

Commit

Permalink
add determinisitc loading
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Dec 21, 2023
1 parent 54454a9 commit 4d52551
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions transformer_nuggets/llama/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class TrainingConfig:
eval_iters: int = 100
log_interval: int = 200
val_step_count: int = 0
deterministic_data_loading: bool = False

# This overfit param is used to test numerical issues by overfitting
# on a single batch. It should be set to False for normal training.
Expand Down Expand Up @@ -331,30 +332,36 @@ def load_datasets(hyper_params: Hyperparameters, training_config: TrainingConfig
train_data = Dataset(
str(training_config.data_dir / "train.bin"),
max_seq_length=hyper_params.max_seq_length,
overfit=training_config.overfit,
training_config=training_config,
)
val_data = Dataset(
str(training_config.data_dir / "val.bin"),
max_seq_length=hyper_params.max_seq_length,
overfit=training_config.overfit,
training_config=training_config,
)
return train_data, val_data


class Dataset(IterableDataset):
def __init__(self, data_file: Path, max_seq_length: int, overfit: bool = False):
def __init__(self, data_file: Path, max_seq_length: int, training_config: TrainingConfig):
super().__init__()
self.data_file = data_file
self.max_seq_length = max_seq_length
self.overfit = overfit
self.overfit = training_config.overfit
self.deterministic_data_loading = training_config.deterministic_data_loading
self.index = 0

def __iter__(self):
data = np.memmap(self.data_file, dtype=np.uint16, mode="r")
while True:
if self.overfit:
i = 0
else:
i = torch.randint(len(data) - self.max_seq_length, (1,)).item()
if self.deterministic_data_loading:
i = self.index
self.index += self.max_seq_length
else:
i = torch.randint(len(data) - self.max_seq_length, (1,)).item()
x = torch.from_numpy((data[i : i + self.max_seq_length]).astype(np.int64))
y = torch.from_numpy((data[i + 1 : i + 1 + self.max_seq_length]).astype(np.int64))
yield x, y
Expand Down

0 comments on commit 4d52551

Please sign in to comment.