Skip to content

Commit

Permalink
[Tune] Add use_threads=False in pyarrow syncing (ray-project#32256)
Browse files Browse the repository at this point in the history
Fixes a pyarrow issue where the syncing deadlocks when there are more files in a directory than available CPU cores.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
2 people authored and edoakes committed Mar 22, 2023
1 parent 3403cc9 commit bfe06f2
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
27 changes: 23 additions & 4 deletions python/ray/air/_internal/remote_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ def create_dir(self, path, recursive):
from ray import logger


def _pyarrow_fs_copy_files(
source, destination, source_filesystem=None, destination_filesystem=None, **kwargs
):
if isinstance(source_filesystem, pyarrow.fs.S3FileSystem) or isinstance(
destination_filesystem, pyarrow.fs.S3FileSystem
):
# Workaround multi-threading issue with pyarrow
# https://github.com/apache/arrow/issues/32372
kwargs.setdefault("use_threads", False)

return pyarrow.fs.copy_files(
source,
destination,
source_filesystem=source_filesystem,
destination_filesystem=destination_filesystem,
**kwargs,
)


def _assert_pyarrow_installed():
if pyarrow is None:
raise RuntimeError(
Expand Down Expand Up @@ -214,9 +233,9 @@ def download_from_uri(uri: str, local_path: str, filelock: bool = True):

if filelock:
with TempFileLock(f"{os.path.normpath(local_path)}.lock"):
pyarrow.fs.copy_files(bucket_path, local_path, source_filesystem=fs)
_pyarrow_fs_copy_files(bucket_path, local_path, source_filesystem=fs)
else:
pyarrow.fs.copy_files(bucket_path, local_path, source_filesystem=fs)
_pyarrow_fs_copy_files(bucket_path, local_path, source_filesystem=fs)


def upload_to_uri(
Expand All @@ -233,7 +252,7 @@ def upload_to_uri(
)

if not exclude:
pyarrow.fs.copy_files(local_path, bucket_path, destination_filesystem=fs)
_pyarrow_fs_copy_files(local_path, bucket_path, destination_filesystem=fs)
return

# Else, walk and upload
Expand Down Expand Up @@ -262,7 +281,7 @@ def _should_exclude(candidate: str) -> bool:
full_source_path = os.path.normpath(os.path.join(local_path, candidate))
full_target_path = os.path.normpath(os.path.join(bucket_path, candidate))

pyarrow.fs.copy_files(
_pyarrow_fs_copy_files(
full_source_path, full_target_path, destination_filesystem=fs
)

Expand Down
40 changes: 40 additions & 0 deletions python/ray/tune/tests/test_syncer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import subprocess
import tempfile
import time
from pathlib import Path
from typing import List, Optional
from unittest.mock import patch

import pytest
import boto3
from freezegun import freeze_time

import ray
Expand All @@ -18,6 +20,7 @@
from ray.tune.syncer import Syncer, _DefaultSyncer
from ray.tune.utils.file_transfer import _pack_dir, _unpack_dir
from ray.air._internal.remote_storage import upload_to_uri, download_from_uri
from ray._private.test_utils import simulate_storage


@pytest.fixture
Expand Down Expand Up @@ -673,6 +676,43 @@ def train_func(config):
)


def test_sync_folder_with_many_files_s3(tmpdir):
# Create 256 files to upload
for i in range(256):
(tmpdir / str(i)).write_text("", encoding="utf-8")

root = "bucket_test_syncer/dir"
with simulate_storage("s3", root) as s3_uri:
# Upload to S3

s3 = boto3.client(
"s3", region_name="us-west-2", endpoint_url="http://localhost:5002"
)
s3.create_bucket(
Bucket="bucket_test_syncer",
CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
)
upload_to_uri(tmpdir, s3_uri)

with tempfile.TemporaryDirectory() as download_dir:
download_from_uri(s3_uri, download_dir)

assert (Path(download_dir) / "255").exists()


def test_sync_folder_with_many_files_fs(tmpdir):
# Create 256 files to upload
for i in range(256):
(tmpdir / str(i)).write_text("", encoding="utf-8")

# Upload to file URI
with tempfile.TemporaryDirectory() as upload_dir:
target_uri = "file://" + upload_dir
upload_to_uri(tmpdir, target_uri)

assert (tmpdir / "255").exists()


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit bfe06f2

Please sign in to comment.