[data] webdataset - flatten return args (ray-project#48674)
## Why are these changes needed?

Closes ray-project#48672


## Checks

- [x] I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

Signed-off-by: jukejian <jukejian@bytedance.com>
Signed-off-by: hjiang <dentinyhao@gmail.com>
Jay-ju authored and dentiny committed Dec 7, 2024
1 parent 29602d4 commit cace421
Showing 2 changed files with 43 additions and 1 deletion.
```diff
@@ -362,4 +362,9 @@ def get_tar_file_iterator():
         for sample in samples:
             if self.decoder is not None:
                 sample = _apply_list(self.decoder, sample, default=_default_decoder)
-            yield pd.DataFrame({k: [v] for k, v in sample.items()})
+            yield pd.DataFrame(
+                {
+                    k: v if isinstance(v, list) and len(v) == 1 else [v]
+                    for k, v in sample.items()
+                }
+            )
```
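The change stops double-wrapping decoder outputs: a value that is already a single-element list is used as the column values directly, while anything else is wrapped into a one-row list as before. A minimal standalone sketch of that logic (the helper name `flatten_sample` is ours, for illustration, not part of Ray's API):

```python
import pandas as pd


def flatten_sample(sample: dict) -> pd.DataFrame:
    # Mirror of the fixed behavior: a single-element list becomes the
    # column values as-is; any other value is wrapped in a one-row list.
    return pd.DataFrame(
        {
            k: v if isinstance(v, list) and len(v) == 1 else [v]
            for k, v in sample.items()
        }
    )


# A decoder that returns a one-element list no longer yields a nested
# list cell; a scalar is still wrapped into a one-row column.
df = flatten_sample({"a": [1], "b": 2})
print(df["a"].iloc[0])  # 1, not [1]
print(df["b"].iloc[0])  # 2
```

Before the fix, `{k: [v]}` would have turned `{"a": [1]}` into a cell holding the list `[1]` rather than the scalar `1`.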
37 changes: 37 additions & 0 deletions in `python/ray/data/tests/test_webdataset.py`

```diff
@@ -201,6 +201,43 @@ def test_webdataset_coding(ray_start_2_cpus, tmp_path):
     assert sample["custom"] == "custom-value"
 
 
+def test_webdataset_decoding(ray_start_2_cpus, tmp_path):
+    import numpy as np
+    import torch
+
+    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    gray = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
+    dstruct = dict(a=np.nan, b=dict(c=2), d="hello", e={"img_filename": "for_test.jpg"})
+    ttensor = torch.tensor([1, 2, 3]).numpy()
+
+    sample = {
+        "__key__": "foo",
+        "jpg": image,
+        "gray.png": gray,
+        "mp": dstruct,
+        "json": dstruct,
+        "pt": ttensor,
+        "und": b"undecoded",
+        "custom": b"nothing",
+    }
+
+    # write the encoded data using the default encoder
+    data = [sample]
+    ds = ray.data.from_items(data).repartition(1)
+    ds.write_webdataset(path=tmp_path, try_create_dir=True)
+
+    ds = ray.data.read_webdataset(
+        paths=[str(tmp_path)],
+        override_num_blocks=1,
+        decoder=None,
+    )
+    samples = ds.take(1)
+    import json
+
+    meta_json = json.loads(samples[0]["json"].decode("utf-8"))
+    assert meta_json["e"]["img_filename"] == "for_test.jpg"
+
+
 @pytest.mark.parametrize("num_rows_per_file", [5, 10, 50])
 def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file):
     ray.data.from_items(
```
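With `decoder=None`, `read_webdataset` leaves every field as raw bytes, which is why the test decodes the `json` field by hand before asserting on it. A minimal sketch of that manual decoding step (the byte string here is an illustrative stand-in for what the test reads back, not actual test output):

```python
import json

# decoder=None means fields arrive as raw bytes; decode them explicitly.
raw = b'{"e": {"img_filename": "for_test.jpg"}}'
meta = json.loads(raw.decode("utf-8"))
print(meta["e"]["img_filename"])  # for_test.jpg
```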
