feat: use a shared queue across workers for data processing #559

Closed
1 change: 1 addition & 0 deletions Makefile
@@ -34,6 +34,7 @@ clean:
rm -rf ./dist

install-dependencies:
pip install -U lightning-sdk
pip install -r requirements.txt
pip install -r requirements/test.txt
pip install -r requirements/docs.txt
176 changes: 128 additions & 48 deletions src/litdata/processing/data_processor.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/litdata/processing/functions.py
@@ -125,7 +125,7 @@ def __init__(
def prepare_structure(self, _: Optional[str]) -> Any:
return self._inputs

def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool) -> None:
def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool = False) -> None:
if self._contains_device and self._device is None:
self._find_device()

1 change: 1 addition & 0 deletions src/litdata/streaming/writer.py
@@ -172,6 +172,7 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:

worker_rank = get_worker_rank()
if worker_rank is not None:
print(flush=True) # to prevent truncated printing when using concurrent threads/processes
print(f"Rank {worker_rank} inferred the following `{data_format}` data format.")
self._data_format = data_format
self._data_spec = data_spec
267 changes: 122 additions & 145 deletions tests/processing/test_data_processor.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions tests/processing/test_functions.py
@@ -96,7 +96,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5)]

with pytest.raises(RuntimeError, match="HINT: If you want to append/overwrite to the existing dataset"):
optimize(
@@ -129,7 +129,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5, 10)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 10)] # each worker can pick items in any order

optimize(
fn=compress,
@@ -143,7 +143,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 10
assert ds[:] == [(i, i**2) for i in range(5, 15)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 15)]

optimize(
fn=compress,
@@ -157,7 +157,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 15
assert ds[:] == [(i, i**2) for i in range(5, 20)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 20)]

with pytest.raises(Exception, match="The config isn't consistent between chunks"):
optimize(
@@ -181,7 +181,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
assert sorted(ds[:]) == [(i, i**2, i**3) for i in range(0, 5)]


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
@@ -216,7 +216,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 4
assert ds[:] == [(i, i**2) for i in range(4)]
assert sorted(ds[:]) == [(i, i**2) for i in range(4)] # for multiple workers, the order of items is not guaranteed
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))

@@ -257,7 +257,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 8
assert ds[:] == [(i, i**2) for i in range(8)]
assert sorted(ds[:]) == [(i, i**2) for i in range(8)]
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))

81 changes: 58 additions & 23 deletions tests/streaming/test_dataset.py
@@ -758,10 +758,12 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
L = len(dataset)
assert L == 20

returned_data = []
for i in range(L):
sequence = dataset[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1
returned_data.append((sequence[0].item(), sequence[-1].item()))
expected_data = [(i * block_size, (i + 1) * block_size - 1) for i in range(L)]
assert sorted(returned_data) == expected_data

monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "0")
@@ -780,11 +782,13 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
# one worker will yield 2 batches, other will yield 3 batches => len(dataloader) = 5
assert len(dataloader) == 5

expected = [[0, 10], [60, 70], [20, 30], [80, 90], [40, 50]]
returned = []
# we can't foresee which items node 0 and node 1 will receive,
# but they will be different and together should cover the complete dataset
# expected = [[0, 10], [60, 70], [20, 30], [80, 90], [40, 50]]
rank_0_returned = []
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected
rank_0_returned.append(batch[:, 0].tolist())
assert len(rank_0_returned) == 5

monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "1")
@@ -795,11 +799,15 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk

assert len(dataloader) == 5

expected = [[100, 110], [160, 170], [120, 130], [180, 190], [140, 150]]
returned = []
rank_1_returned = []
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected
rank_1_returned.append(batch[:, 0].tolist())
assert len(rank_1_returned) == 5

returned_items = sorted(rank_0_returned + rank_1_returned)
assert len(returned_items) == 10
print(f"{returned_items=}")
assert returned_items == [[i, i + 10] for i in range(0, 200, 20)]
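# Illustrative follow-up check (a sketch, not an assertion the original test makes): given the
# combined coverage verified above, the two ranks must have seen disjoint sets of batches.
# `rank_0_starts` / `rank_1_starts` are new names introduced only for this illustration.
rank_0_starts = {batch[0] for batch in rank_0_returned}
rank_1_starts = {batch[0] for batch in rank_1_returned}
assert rank_0_starts.isdisjoint(rank_1_starts)
assert sorted(rank_0_starts | rank_1_starts) == list(range(0, 200, 20))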


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@@ -978,7 +986,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):

@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.timeout(60)
@pytest.mark.timeout(120)
@pytest.mark.parametrize("shuffle", [True, False])
def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
"""Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size."""
@@ -989,6 +997,16 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", optimize_data_cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir)

# 8 * 10 * 10 = 800 items will be stored in chunks of max_size = 190 by 4 workers
# 4 full chunks of 190 items hold 760 items, which leaves 800 - 760 = 40 items over
# with the shared queue, workers pick items in a non-deterministic order, so we can't predict
# the exact chunk count, but we can put a `min-max` value on it:
# min => a single worker writes all leftover items into one extra chunk = 4 + 1 = 5 chunks minimum
# max => each of the 4 workers ends with its own partial chunk of leftovers = 4 + 4 = 8 chunks maximum

optimize(
fn=_simple_preprocess,
inputs=list(range(8)),
@@ -998,17 +1016,30 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
num_uploaders=1,
item_loader=TokensLoader(block_size=10),
)
assert set(os.listdir(data_dir)) == {
"chunk-0-0.bin",
"chunk-0-1.bin",
"chunk-1-0.bin",
"chunk-1-1.bin",
"chunk-2-0.bin",
"chunk-2-1.bin",
"chunk-3-0.bin",
"chunk-3-1.bin",
"index.json",
}
# print(f"{os.listdir(data_dir)=}")
# # print items in head of each
# for file_name in os.listdir(data_dir):
# file_path = os.path.join(data_dir, file_name)

# with open(file_path, "rb") as f:
# head_bytes = f.read(4) # read first 4 bytes
# if len(head_bytes) < 4:
# print(f"{file_name}: File too short")
# continue
# val = np.frombuffer(head_bytes, dtype=np.int32)[0]
# print(f"{file_name}: {val}")
assert 6 <= len(os.listdir(data_dir)) <= 9 # +1 for index.json file

# check if the dataloader contains the complete dataset
os.mkdir(s3_cache_dir)
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle)

fetched_dataset = []
for i, batch in enumerate(train_dataloader):
fetched_dataset.extend(batch)
assert len(fetched_dataset) == 80

shutil.rmtree(s3_cache_dir)

os.mkdir(s3_cache_dir)
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle)
@@ -1029,8 +1060,12 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
assert dataloader_state is not None
assert batch_to_resume_from is not None
train_dataloader.load_state_dict(dataloader_state)
print(f"{dataloader_state=}")
print(f"{batch_to_resume_from=}")
next_batch_data = next(iter(train_dataloader))
print(f"{next_batch_data=}")
# The next batch after resuming must match what we should have gotten next in the initial loop
assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
assert torch.equal(next_batch_data, batch_to_resume_from)


@pytest.mark.timeout(60)