Fast and extendable dataset sampling #110

Draft · wants to merge 7 commits into base: modular_dataset
2 changes: 1 addition & 1 deletion LICENSE

@@ -178,7 +178,7 @@ The following applies to all files unless otherwise noted:
 
 END OF TERMS AND CONDITIONS
 
-Copyright 2024 ServiceNow, Inc.
+Copyright 2024-2025 ServiceNow, Inc.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion docs/license.md

@@ -5,7 +5,7 @@ title: License
 Fast-LLM is licenced under the Apache 2.0 license:
 
 ```text
-Copyright 2024 ServiceNow, Inc.
+Copyright 2024-2025 ServiceNow, Inc.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
6 changes: 3 additions & 3 deletions docs/quick-start.md

@@ -49,7 +49,7 @@ Now, select the compute environment that matches your setup or preferred workflow
 
 Install Python 3.12 (or later) if it's not already available on your system. For a Python virtual environment, run:
 
 ```bash
-python3.10 -m venv ./fast-llm-tutorial/venv
+python3.12 -m venv ./fast-llm-tutorial/venv
 source ./fast-llm-tutorial/venv/bin/activate
 pip install --upgrade pip
 ```

@@ -202,11 +202,11 @@ Choose based on your goals for this tutorial.
 
 === "Big"
 
-    For the big configuration, we'll use a Llama model with 8B parameters. We'll grab the model from the Huggingface Hub and save it to our inputs folder.
+    For the big configuration, we'll use a Llama model with 8B parameters. We'll grab the model from the HuggingFace Hub and save it to our inputs folder.
 
     !!! note "Access Required"
 
-        Meta gates access to their Llama models. You need to request access to the model from Meta before you can download it at https://huggingface.co/meta-llama/Llama-3.1-8B. You'll need to authenticate with your Hugging Face account to download the model:
+        Meta gates access to their Llama models. You need to request access to the model from Meta before you can download it at https://huggingface.co/meta-llama/Llama-3.1-8B. You'll need to authenticate with your HuggingFace account to download the model:
 
         ```bash
         pip install huggingface_hub
64 changes: 23 additions & 41 deletions fast_llm/csrc/data.cpp

@@ -104,9 +104,7 @@ void build_blending_indices(py::array_t<int16_t>& dataset_index,
 py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                            const py::array_t<int32_t>& doc_idx_,
                            const int32_t seq_length,
-                           const int32_t num_epochs,
-                           const int64_t tokens_per_epoch,
-                           const bool verbose) {
+                           const int64_t num_samples) {
     /* Sample index (sample_idx) is used for gpt2 like dataset for which
        the documents are flattened and the samples are built based on this
        1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
@@ -115,29 +113,14 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
 
     // Consistency checks.
     assert(seq_length > 1);
-    assert(num_epochs > 0);
-    assert(tokens_per_epoch > 1);
 
     // Remove bound checks.
     auto sizes = sizes_.unchecked<1>();
     auto doc_idx = doc_idx_.unchecked<1>();
 
-    // Mapping and it's length (1D).
-    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
     int32_t* sample_idx = new int32_t[2*(num_samples+1)];
 
-    if (verbose) {
-        cout << "    using:" << endl << std::flush;
-        cout << "     number of documents:     " <<
-            doc_idx_.shape(0) / num_epochs << endl << std::flush;
-        cout << "     number of epochs:        " << num_epochs <<
-            endl << std::flush;
-        cout << "     sequence length:         " << seq_length <<
-            endl << std::flush;
-        cout << "     total number of samples: " << num_samples <<
-            endl << std::flush;
-    }
 
     // Index into sample_idx.
     int64_t sample_index = 0;
     // Index into doc_idx.
@@ -151,30 +134,29 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
 
     while (sample_index <= num_samples) {
         // Start with a fresh sequence.
-      int32_t remaining_seq_length = seq_length + 1;
-      while (remaining_seq_length != 0) {
+        int32_t remaining_seq_length = seq_length + 1;
+        while (remaining_seq_length > 0) {
             // Get the document length.
-          auto doc_id = doc_idx[doc_idx_index];
-          auto doc_length = sizes[doc_id] - doc_offset;
-          // And add it to the current sequence.
-          remaining_seq_length -= doc_length;
-          // If we have more than a full sequence, adjust offset and set
-          // remaining length to zero so we return from the while loop.
-          // Note that -1 here is for the same reason we have -1 in
-          // `_num_epochs` calculations.
-          if (remaining_seq_length <= 0) {
-              doc_offset += (remaining_seq_length + doc_length - 1);
-              remaining_seq_length = 0;
-          } else {
-              // Otherwise, start from the beginning of the next document.
-              ++doc_idx_index;
-              doc_offset = 0;
-          }
-      }
-      // Record the sequence.
-      sample_idx[2 * sample_index] = doc_idx_index;
-      sample_idx[2 * sample_index + 1] = doc_offset;
-      ++sample_index;
+            auto doc_id = doc_idx[doc_idx_index];
+            auto doc_length = sizes[doc_id] - doc_offset;
+            // And add it to the current sequence.
+            remaining_seq_length -= doc_length;
+            // If we have more than a full sequence, adjust offset and set
+            // remaining length to zero so we return from the while loop.
+            // Note that -1 here is for the same reason we have -1 in
+            // `_num_epochs` calculations.
+            if (remaining_seq_length <= 0) {
+                doc_offset += (remaining_seq_length + doc_length - 1);
+            } else {
+                // Otherwise, start from the beginning of the next document.
+                ++doc_idx_index;
+                doc_offset = 0;
+            }
+        }
+        // Record the sequence.
+        sample_idx[2 * sample_index] = doc_idx_index;
+        sample_idx[2 * sample_index + 1] = doc_offset;
+        ++sample_index;
     }
 
     // Method to deallocate memory.
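The net effect of this change: `build_sample_idx` now receives `num_samples` directly instead of deriving it from `num_epochs` and `tokens_per_epoch`, the inner loop exits via the `> 0` condition rather than forcing `remaining_seq_length` to zero, and the verbose logging leaves the C++ extension. As a reading aid, here is a minimal NumPy sketch of the index the function builds (illustrative only; the authoritative implementation is the C++ above):

```python
import numpy as np

def build_sample_idx(sizes: np.ndarray, doc_idx: np.ndarray, seq_length: int, num_samples: int) -> np.ndarray:
    # Row i is the (document index, offset) where sample i starts in the
    # flattened token stream; rows i and i + 1 together delimit sample i.
    sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int32)
    doc_idx_index, doc_offset = 0, 0
    for sample_index in range(1, num_samples + 1):
        remaining_seq_length = seq_length + 1
        while remaining_seq_length > 0:
            doc_length = sizes[doc_idx[doc_idx_index]] - doc_offset
            remaining_seq_length -= doc_length
            if remaining_seq_length <= 0:
                # The sample ends inside this document; -1 so its last token
                # is reused as the first token of the next sample.
                doc_offset += remaining_seq_length + doc_length - 1
            else:
                # Continue from the start of the next document.
                doc_idx_index += 1
                doc_offset = 0
        sample_idx[sample_index] = (doc_idx_index, doc_offset)
    return sample_idx
```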
11 changes: 11 additions & 0 deletions fast_llm/data/data/gpt/config.py

@@ -37,6 +37,17 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.gt, 0),
     )
+    shuffle_epochs: bool = Field(
+        default=True,
+        desc="Shuffle all epochs together. Adds extra randomness,"
+        " but makes it harder to change the training length after training is started.",
+        hint=FieldHint.feature,
+    )
+    distributed_data_sampling: bool = Field(
+        default=True,
+        desc="When possible, distribute data sampling across all available processes to speed it up.",
+        hint=FieldHint.performance,
+    )
     multiprocessing_context: MultiprocessingContext = Field(
         default=MultiprocessingContext.spawn,
         desc="Multiprocessing context. Do not touch.",
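Both new fields are plain booleans on `GPTDataConfig`, so they can be flipped from a config dict or YAML. A hedged sketch of toggling them (this assumes the `from_dict` constructor that Fast-LLM config classes expose; the field names come from this diff, everything else is illustrative):

```python
from fast_llm.data.data.gpt.config import GPTDataConfig

# Keep per-epoch shuffling so the training length can still be changed
# mid-run, while letting index building run in parallel across processes.
data_config = GPTDataConfig.from_dict(
    {
        "shuffle_epochs": False,
        "distributed_data_sampling": True,
    }
)
```

The trade-off named in the `shuffle_epochs` description is the usual one: shuffling all epochs together gives better global randomness, but the resulting index depends on the total number of samples, so extending or shortening training invalidates the cached sampling.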
25 changes: 10 additions & 15 deletions fast_llm/data/dataset/blended.py

@@ -4,7 +4,7 @@
 
 import numpy as np
 
 from fast_llm.core.distributed import safe_barrier
-from fast_llm.data.data.config import SamplingConfig
 from fast_llm.data.dataset.abstract import SampledDataset
+from fast_llm.data.dataset.config import SamplingConfig
 from fast_llm.engine.config_utils.run import log_main_rank
@@ -44,9 +44,7 @@ def __init__(
 
         if sampling_config.cache_directory is None:
             self._dataset_idx_filename, self._sample_idx_filename = None, None
-            self._dataset_index, self._sample_index = self._build_blending_indices(
-                sampling_config.verbose and len(self._datasets) <= 20
-            )
+            self._dataset_index, self._sample_index = self._build_blending_indices()
         else:
             group = sampling_config.distributed.world_group
             self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy")
@@ -57,16 +55,11 @@ def __init__(
             if (group is None or group.rank() == 0) and not (
                 self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
             ):
-                dataset_index, sample_index = self._build_blending_indices(
-                    sampling_config.verbose and len(self._datasets) <= 20
-                )
+                dataset_index, sample_index = self._build_blending_indices()
                 sampling_config.cache_directory.mkdir(exist_ok=True, parents=True)
                 np.save(self._dataset_idx_filename, dataset_index)
                 np.save(self._sample_idx_filename, sample_index)
 
-            safe_barrier(group, self._name)
-            self._load_mappings(sampling_config.verbose)
-
     def __getstate__(self) -> tuple[typing.Any, ...]:
         return (
             self._datasets,
@@ -88,12 +81,13 @@ def __setstate__(self, state: tuple[typing.Any, ...]):
         ) = state
         if isinstance(dataset_index, pathlib.Path):
            self._dataset_idx_filename, self._sample_idx_filename = dataset_index, sample_index
-            self._load_mappings(False)
         else:
             self._dataset_idx_filename, self._sample_idx_filename = None, None
             self._dataset_index, self._sample_index = dataset_index, sample_index
 
-    def _load_mappings(self, verbose: bool) -> None:
+    def _load_mappings(self, verbose: bool = False) -> None:
+        if hasattr(self, "_dataset_index"):
+            return
         if verbose:
             log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}")
         self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r")
@@ -104,7 +98,7 @@ def _load_mappings(self, verbose: bool) -> None:
     def __len__(self) -> int:
         return self._num_samples
 
-    def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
+    def _build_blending_indices(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray]:
         assert _extension_available, (
             "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
         )
@@ -117,7 +111,8 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
             self._weights,
             len(self._datasets),
             self._num_samples,
-            verbose,
+            # TODO: Verbose option?
+            True,  # verbose
         )
         available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets])
         sampled_per_dataset = np.bincount(dataset_index)
@@ -137,7 +132,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
         return dataset_index, dataset_sample_index
 
     def __getitem__(self, idx: int) -> typing.Any:
-        return self._datasets[self._dataset_index[idx]][self._sample_index[idx]]
+        return self._datasets[self._dataset_index[idx]][self._sample_index[idx].item()]
 
     @property
     def name(self):
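Two orientation notes on this file. First, the `.item()` added in `__getitem__` converts the NumPy scalar read from the memory-mapped index array into a plain Python `int` before it is used to index the underlying dataset. Second, the dataset/sample index pair that `_build_blending_indices` caches follows the usual Megatron-style blending rule; below is a Python sketch of that rule (illustrative only; the real work happens in the C++ extension, whose `verbose` plumbing is still a TODO in this draft):

```python
import numpy as np

def build_blending_indices(weights: np.ndarray, num_samples: int) -> tuple[np.ndarray, np.ndarray]:
    """For each global sample, pick the dataset whose sampled count lags
    its target fraction the most (Megatron-style blending)."""
    weights = weights / weights.sum()
    dataset_index = np.zeros(num_samples, dtype=np.int16)
    dataset_sample_index = np.zeros(num_samples, dtype=np.int64)
    counts = np.zeros(len(weights), dtype=np.int64)
    for i in range(num_samples):
        # Largest shortfall between target and achieved sample counts.
        chosen = int(np.argmax(weights * (i + 1) - counts))
        dataset_index[i] = chosen
        dataset_sample_index[i] = counts[chosen]  # next unused sample of that dataset
        counts[chosen] += 1
    return dataset_index, dataset_sample_index
```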