Image cast_storage very slow for arrays (e.g. numpy, tensors) #6782

Open
Modexus opened this issue Apr 5, 2024 · 3 comments

@Modexus
Contributor

Modexus commented Apr 5, 2024

Update: see comments below

Describe the bug

Operations that save an image from a path are very slow.
I believe the reason is that the image data (numpy) is converted into pyarrow format, then back to Python objects using .to_pylist(), and only then converted to a numpy array again.

to_pylist() is already slow, and on a multi-dimensional array such as an image it takes a very long time.

From the trace below we can see that __arrow_array__ takes a long time.
It is currently also called in get_inferred_type. That call should be removable (#6781), but removing it doesn't change the underlying issue.

The conversion to pyarrow and back also leaves the numpy array with dtype int64, which causes a warning message because the image type expects uint8.
However, originally the numpy image array was in uint8.

Steps to reproduce the bug

from PIL import Image
import numpy as np
import datasets
import cProfile

image = Image.fromarray(np.random.randint(0, 255, (2048, 2048, 3), dtype=np.uint8))
image.save("test_image.jpg")

ds = datasets.Dataset.from_dict(
    {"image": ["test_image.jpg"]},
    features=datasets.Features({"image": datasets.Image(decode=True)}),
)

# load as numpy array, e.g. for further processing with map
# same result as map returning numpy arrays
ds.set_format("numpy")

cProfile.run("ds.map(writer_batch_size=1, load_from_cache_file=False)", "restats")

Fri Apr  5 14:56:17 2024    restats

         66817 function calls (64992 primitive calls) in 33.382 seconds

   Ordered by: cumulative time
   List reduced from 1073 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     46/1    0.000    0.000   33.382   33.382 {built-in method builtins.exec}
        1    0.000    0.000   33.382   33.382 <string>:1(<module>)
        1    0.000    0.000   33.382   33.382 arrow_dataset.py:594(wrapper)
        1    0.000    0.000   33.382   33.382 arrow_dataset.py:551(wrapper)
        1    0.000    0.000   33.379   33.379 arrow_dataset.py:2916(map)
        4    0.000    0.000   33.327    8.332 arrow_dataset.py:3277(_map_single)
        1    0.000    0.000   33.311   33.311 arrow_writer.py:465(write)
        2    0.000    0.000   33.311   16.656 arrow_writer.py:423(write_examples_on_file)
        1    0.000    0.000   33.311   33.311 arrow_writer.py:527(write_batch)
        2   14.484    7.242   33.260   16.630 arrow_writer.py:161(__arrow_array__)
        1    0.001    0.001   16.438   16.438 arrow_writer.py:121(get_inferred_type)
        1    0.000    0.000   14.398   14.398 threading.py:637(wait)
        1    0.000    0.000   14.398   14.398 threading.py:323(wait)
        8   14.398    1.800   14.398    1.800 {method 'acquire' of '_thread.lock' objects}
      4/2    0.000    0.000    4.337    2.169 table.py:1800(wrapper)
        2    0.000    0.000    4.337    2.169 table.py:1950(cast_array_to_feature)
        2    0.475    0.238    4.337    2.169 image.py:209(cast_storage)
        9    2.583    0.287    2.583    0.287 {built-in method numpy.array}
        2    0.000    0.000    1.284    0.642 image.py:319(encode_np_array)
        2    0.000    0.000    1.246    0.623 image.py:301(image_to_bytes)

Expected behavior

The numpy image data should be passed through as it will be directly consumed by pillow to convert it to bytes.

As a test, one can replace list_of_np_array_to_pyarrow_listarray(data) in __arrow_array__ with just out = data.
cast_storage of the Image feature then has to be changed so that it handles the passed-through data (and the type check in the if before it):

bytes_array = pa.array(
    [encode_np_array(arr)["bytes"] if arr is not None else None for arr in storage],
    type=pa.binary(),
)

Leading to the following:

Fri Apr  5 15:44:27 2024    restats

         66419 function calls (64595 primitive calls) in 0.937 seconds

   Ordered by: cumulative time
   List reduced from 1023 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     47/1    0.000    0.000    0.935    0.935 {built-in method builtins.exec}
      2/1    0.000    0.000    0.935    0.935 <string>:1(<module>)
      2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:594(wrapper)
      2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:551(wrapper)
      2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:2916(map)
        4    0.000    0.000    0.933    0.233 arrow_dataset.py:3277(_map_single)
        1    0.000    0.000    0.883    0.883 arrow_writer.py:466(write)
        2    0.000    0.000    0.883    0.441 arrow_writer.py:424(write_examples_on_file)
        1    0.000    0.000    0.882    0.882 arrow_writer.py:528(write_batch)
        2    0.000    0.000    0.877    0.439 arrow_writer.py:161(__arrow_array__)
      4/2    0.000    0.000    0.877    0.439 table.py:1800(wrapper)
        2    0.000    0.000    0.877    0.439 table.py:1950(cast_array_to_feature)
        2    0.009    0.005    0.877    0.439 image.py:209(cast_storage)
        2    0.000    0.000    0.868    0.434 image.py:335(encode_np_array)
        2    0.000    0.000    0.856    0.428 image.py:317(image_to_bytes)
        2    0.000    0.000    0.822    0.411 Image.py:2376(save)
        2    0.000    0.000    0.822    0.411 PngImagePlugin.py:1233(_save)
        2    0.000    0.000    0.822    0.411 ImageFile.py:517(_save)
        2    0.000    0.000    0.821    0.411 ImageFile.py:545(_encode_tile)
      589    0.803    0.001    0.803    0.001 {method 'encode' of 'ImagingEncoder' objects}

This is of course only a test, as it passes through all numpy arrays regardless of whether they should be images.
Also, I guess cast_storage is meant for casting pyarrow storage exclusively.
Converting to a pyarrow array seems like a good solution since it also handles pytorch tensors etc. Maybe there is a more efficient way to create a PIL image from a pyarrow array?

Not sure how this should be handled but I would be happy to help if there is a good solution.

Environment info

  • datasets version: 2.18.1.dev0
  • Platform: Linux-6.7.11-200.fc39.x86_64-x86_64-with-glibc2.38
  • Python version: 3.12.2
  • huggingface_hub version: 0.22.2
  • PyArrow version: 15.0.2
  • Pandas version: 2.2.1
  • fsspec version: 2024.3.1
@Modexus
Contributor Author

Modexus commented Apr 5, 2024

This may be a solution that only changes cast_storage of Image.
However, I'm not totally sure that the assumptions it makes about the ListArray hold.

elif pa.types.is_list(storage.type):
    from .features import Array3DExtensionType

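    def get_shapes(arr):
        # Derive the per-image shape by flattening one nesting level at a time.
        # Assumes every nested list at a given level has the same length.
        shape = ()
        while isinstance(arr, pa.ListArray):
            len_curr = len(arr)
            arr = arr.flatten()
            len_new = len(arr)
            shape = shape + (len_new // len_curr,)
        return shape

    def get_dtypes(arr):
        # Walk down the nested list type of the given array to its primitive value type.
        dtype = arr.type
        while hasattr(dtype, "value_type"):
            dtype = dtype.value_type
        return dtype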

    arrays = []
    for i, is_null in enumerate(storage.is_null()):
        if not is_null.as_py():
            storage_part = storage.take([i])
            shape = get_shapes(storage_part)
            dtype = get_dtypes(storage_part)

            extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype))
            array = pa.ExtensionArray.from_storage(extension_type, storage_part)
            arrays.append(array.to_numpy().squeeze(0))
        else:
            arrays.append(None)

    bytes_array = pa.array(
        [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays],
        type=pa.binary(),
    )
    path_array = pa.array([None] * len(storage), type=pa.string())
    storage = pa.StructArray.from_arrays(
        [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
    )

(Edited): to handle nulls

Notably, this changes nothing about how the data is passed through or anything else; it only touches the Image class.
Seems quite fast:

Fri Apr  5 17:55:51 2024    restats

         63818 function calls (61995 primitive calls) in 0.812 seconds

   Ordered by: cumulative time
   List reduced from 1051 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     47/1    0.000    0.000    0.810    0.810 {built-in method builtins.exec}
      2/1    0.000    0.000    0.810    0.810 <string>:1(<module>)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:594(wrapper)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:551(wrapper)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:2916(map)
        3    0.000    0.000    0.807    0.269 arrow_dataset.py:3277(_map_single)
        1    0.000    0.000    0.760    0.760 arrow_writer.py:589(finalize)
        1    0.000    0.000    0.760    0.760 arrow_writer.py:423(write_examples_on_file)
        1    0.000    0.000    0.759    0.759 arrow_writer.py:527(write_batch)
        1    0.001    0.001    0.754    0.754 arrow_writer.py:161(__arrow_array__)
      2/1    0.000    0.000    0.719    0.719 table.py:1800(wrapper)
        1    0.000    0.000    0.719    0.719 table.py:1950(cast_array_to_feature)
        1    0.006    0.006    0.718    0.718 image.py:209(cast_storage)
        1    0.000    0.000    0.451    0.451 image.py:361(encode_np_array)
        1    0.000    0.000    0.444    0.444 image.py:343(image_to_bytes)
        1    0.000    0.000    0.413    0.413 Image.py:2376(save)
        1    0.000    0.000    0.413    0.413 PngImagePlugin.py:1233(_save)
        1    0.000    0.000    0.413    0.413 ImageFile.py:517(_save)
        1    0.000    0.000    0.413    0.413 ImageFile.py:545(_encode_tile)
      397    0.409    0.001    0.409    0.001 {method 'encode' of 'ImagingEncoder' objects}

@jdf-prog

jdf-prog commented Apr 5, 2024

I also encounter this problem. I have been struggling with it for a long time...

@Modexus Modexus changed the title from "Map/Saving Image from external filepath extremely slow" to "Image cast_storage very slow for arrays (e.g. numpy, tensors)" on Apr 10, 2024
@Modexus
Contributor Author

Modexus commented Apr 10, 2024

This actually applies to all arrays (numpy or tensors like in torch), not only from external files.

import numpy as np
import datasets

ds = datasets.Dataset.from_dict(
    {"image": [np.random.randint(0, 255, (2048, 2048, 3), dtype=np.uint8)]},
    features=datasets.Features({"image": datasets.Image(decode=True)}),
)
ds.set_format("numpy")

ds = ds.map(load_from_cache_file=False)
