Improve StreamingDataset Speed #19114

Merged · 10 commits · Dec 5, 2023
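Summary of the changes below: PrepareChunksThread drops its threading.Lock and shared index lists in favor of three multiprocessing queues (to-download, to-delete, to-stop) that run() polls with short timeouts; the pre_download back-pressure counters are removed; cached chunks are evicted two at a time once disk usage exceeds max_cache_size; and the per-iteration sleep drops from 0.1s to 0.01s. The expected cache-growth values in test_reader_chunk_removal are updated to match.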
126 changes: 60 additions & 66 deletions src/lightning/data/streaming/reader.py
@@ -11,10 +11,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import multiprocessing
 import os
 import shutil
 import warnings
-from threading import Lock, Thread
+from queue import Empty
+from threading import Thread
 from time import sleep
 from typing import Any, Dict, List, Optional, Tuple

@@ -34,35 +36,25 @@
 class PrepareChunksThread(Thread):
     """This thread is responsible to download the chunks associated to a given worker."""

-    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None:
+    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
         super().__init__(daemon=True)
         self._config = config
-        self._chunks_index_to_be_downloaded: List[int] = []
         self._chunks_index_to_be_deleted: List[int] = []
-        self._lock = Lock()
         self._max_cache_size = max_cache_size
-        self._downloaded_chunks = 0
-        self._processed_chunks = 0
-        self._processed_chunks_counter = 0
-        self._delete_chunks = 0
-        self._pre_download = pre_download
-        self._should_stop = False
-
-    def download(self, chunk_indices: List[int]) -> None:
-        """Receive the list of the chunk indices to download for the current epoch."""
-        with self._lock:
-            for chunk_indice in chunk_indices:
-                if chunk_indice not in self._chunks_index_to_be_downloaded:
-                    self._chunks_index_to_be_downloaded.append(chunk_indice)
+        self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()

-    def delete(self, chunk_indices: List[int]) -> None:
+    def download(self, chunk_indexes: List[int]) -> None:
+        """Receive the list of the chunk indices to download for the current epoch."""
+        for chunk_index in chunk_indexes:
+            self._to_download_queue.put(chunk_index)
+
+    def delete(self, chunk_indexes: List[int]) -> None:
         """Receive the list of the chunk indices to delete for the current epoch."""
-        with self._lock:
-            for chunk_indice in chunk_indices:
-                if chunk_indice not in self._chunks_index_to_be_deleted:
-                    self._chunks_index_to_be_deleted.append(chunk_indice)
-                self._processed_chunks += 1
-                self._processed_chunks_counter += 1
+        for chunk_index in chunk_indexes:
+            self._to_delete_queue.put(chunk_index)

     def _delete(self, chunk_index: int) -> None:
         chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
@@ -72,54 +64,56 @@ def _delete(self, chunk_index: int) -> None:
     def stop(self) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
-        with self._lock:
-            self._should_stop = True
+        self._to_stop_queue.put(None)

     def run(self) -> None:
         while True:
-            with self._lock:
-                if self._should_stop:
-                    if (
-                        self._max_cache_size
-                        and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total
-                    ):
-                        for chunk_index in self._chunks_index_to_be_deleted:
-                            if chunk_index not in self._chunks_index_to_be_downloaded:
-                                self._delete(chunk_index)
-                                self._delete_chunks += 1
-                                self._processed_chunks_counter = 0
-                    return
-
-                # Wait for something to do
-                if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
-                    continue
-
-                # Delete the chunks if we are missing disk space.
-                if self._max_cache_size and self._processed_chunks_counter >= self._pre_download:
-                    if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
-                        for chunk_index in self._chunks_index_to_be_deleted:
-                            if chunk_index not in self._chunks_index_to_be_downloaded:
-                                self._delete(chunk_index)
-                                self._delete_chunks += 1
-                                self._processed_chunks_counter = 0
-                        self._chunks_index_to_be_deleted = []
-
-                # If there is no chunks to download, go back to waiting
-                if len(self._chunks_index_to_be_downloaded) == 0:
-                    continue
-
-                # If we have already downloaded too many chunks, let's wait for processed chunks to catch up
-                if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
-                    sleep(0.1)
-                    continue
-
-                chunk_index = self._chunks_index_to_be_downloaded.pop(0)
-
-            self._config.download_chunk_from_index(chunk_index)
-            self._downloaded_chunks += 1
+            try:
+                chunk_index = self._to_download_queue.get(timeout=0.01)
+                self._config.download_chunk_from_index(chunk_index)
+            except Empty:
+                pass
+            except OSError as e:
+                # handle closed queue before the thread terminates
+                if "handle is closed" in str(e):
+                    pass
+                else:
+                    raise e

+            try:
+                chunk_index = self._to_delete_queue.get(timeout=0.01)
+                if self._max_cache_size:
+                    if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
+                        self._chunks_index_to_be_deleted.append(chunk_index)
+
+                        # Delete 2 chunk at the time to give enough space while not blocking downloads
+                        for chunk_index in self._chunks_index_to_be_deleted[:2]:
+                            self._delete(chunk_index)
+
+                        self._chunks_index_to_be_deleted = self._chunks_index_to_be_deleted[2:]
+                    else:
+                        self._chunks_index_to_be_deleted.append(chunk_index)
+            except Empty:
+                pass
+            except OSError as e:
+                # handle closed queue before the thread terminates
+                if "handle is closed" in str(e):
+                    pass
+                else:
+                    raise e
+
+            try:
+                self._to_stop_queue.get(timeout=0.01)
+                return
+            except Empty:
+                pass
+            except OSError as e:
+                # handle closed queue before the thread terminates
+                if "handle is closed" in str(e):
+                    return
+                raise e

-            # Sleep to release the lock
-            sleep(0.1)
+            sleep(0.01)


 class BinaryReader:
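The core of the change is the switch from lock-guarded shared state to queue polling. Below is a minimal, self-contained sketch of that pattern — the QueueWorker class, its submit() method, and the printed item handling are hypothetical stand-ins, not Lightning code — showing how a daemon thread drains a work queue and a stop queue with short timeouts, the way the new PrepareChunksThread.run does:

    import multiprocessing
    from queue import Empty
    from threading import Thread
    from time import sleep


    class QueueWorker(Thread):
        """Hypothetical stand-in for a PrepareChunksThread-style worker."""

        def __init__(self) -> None:
            super().__init__(daemon=True)
            # Producers only enqueue; no mutable list is shared with the thread.
            self._work_queue: multiprocessing.Queue = multiprocessing.Queue()
            self._stop_queue: multiprocessing.Queue = multiprocessing.Queue()

        def submit(self, items: list) -> None:
            for item in items:
                self._work_queue.put(item)

        def stop(self) -> None:
            self._stop_queue.put(None)

        def run(self) -> None:
            while True:
                try:
                    # Short timeout: fall through to the stop check instead of blocking.
                    item = self._work_queue.get(timeout=0.01)
                    print(f"processing {item}")  # stands in for download_chunk_from_index
                except Empty:
                    pass

                try:
                    self._stop_queue.get(timeout=0.01)
                    return
                except Empty:
                    pass

                sleep(0.01)


    worker = QueueWorker()
    worker.start()
    worker.submit(list(range(3)))
    sleep(0.5)
    worker.stop()
    worker.join()

Polling with a small timeout keeps the loop responsive to stop requests without busy-waiting, and handing indexes over a queue removes the lock contention between the reader and the worker thread.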
44 changes: 22 additions & 22 deletions tests/tests_data/streaming/test_reader.py
@@ -45,38 +45,38 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
     shutil_mock.disk_usage.return_value = disk_usage
     monkeypatch.setattr(reader, "shutil", shutil_mock)

-    expected = []
+    generated = []
     for i in range(25):
-        expected.append([i, len(os.listdir(cache_dir))])
+        generated.append([i, len(os.listdir(cache_dir))])
         index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i

-    assert expected == [
+    assert generated == [
         [0, 0],
         [1, 1],
         [2, 1],
         [3, 2],
         [4, 2],
-        [5, 3],
-        [6, 3],
-        [7, 4],
-        [8, 4],
-        [9, 5],
-        [10, 5],
-        [11, 6],
-        [12, 6],
-        [13, 7],
-        [14, 7],
-        [15, 8],
-        [16, 8],
-        [17, 9],
-        [18, 9],
-        [19, 10],
-        [20, 10],
+        [5, 2],
+        [6, 2],
+        [7, 2],
+        [8, 2],
+        [9, 2],
+        [10, 2],
+        [11, 2],
+        [12, 2],
+        [13, 2],
+        [14, 2],
+        [15, 2],
+        [16, 2],
+        [17, 2],
+        [18, 2],
+        [19, 2],
+        [20, 2],
         [21, 2],
         [22, 2],
-        [23, 3],
-        [24, 3],
+        [23, 2],
+        [24, 2],
     ]

-    assert len(os.listdir(cache_dir)) in [3, 4]
+    assert len(os.listdir(cache_dir)) == 2
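Reading the new expectations against the reader change above: once the mocked shutil.disk_usage reports usage at or above max_cache_size, each chunk index arriving on the delete queue triggers eviction of up to two previously buffered chunks, so after the first few samples the cache directory holds a steady two chunks instead of accumulating roughly one new chunk every two samples, and the final listing is deterministically 2 rather than the looser 3-or-4 range asserted before.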