
add prototype utilities to read arbitrary numeric binary files #4882

Merged: 34 commits into main, Nov 19, 2021

Conversation

pmeier (Collaborator) commented on Nov 8, 2021

This is not needed right now, but it helps @krshrimali, who is porting the Sintel "legacy" dataset to prototypes. It has quite a few parallels to what we do for MNIST:

class MNISTFileReader(IterDataPipe[torch.Tensor]):

We can probably merge these two in the future.

cc @pmeier @bjuncek

facebook-github-bot commented on Nov 8, 2021

💊 CI failures summary and remediations

As of commit 0969cf9 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.


[3 review comments on torchvision/prototype/datasets/datapipes.py, since resolved]
pmeier (Collaborator, Author) commented on Nov 8, 2021

After some more offline discussion with @NicolasHug, I've refactored the PR to use numpy for reading the data, since numpy can natively handle endianness.

Note that we cannot use numpy.fromfile, because it doesn't support reading from compressed files like the ones the MNIST dataset uses:

import gzip
import tempfile

import numpy as np

data = np.array([0.0])

path = tempfile.mktemp()
with open(path, "wb") as file:
    file.write(gzip.compress(data.tobytes()))


with open(path, "rb") as compressed_file:
    with gzip.open(compressed_file, "rb") as file:
        print(np.fromfile(file))  # bypasses the gzip wrapper and reads the raw, compressed bytes

        file.seek(0)
        print(np.frombuffer(file.read()))  # reads the decompressed bytes

This prints:

[7.1653857e+161 5.9808143e+197]
[0.]

If we go through with this design, we should revisit #4598. It should probably then be reverted for consistency. cc @datumbox

@pmeier changed the title from "add FloReader datapipe" to "add prototype utilities to read arbitrary numeric binary files" on Nov 8, 2021
NicolasHug (Member) commented

My main concern wasn't just about using numpy vs. pytorch; it was mostly about code complexity. The original flow decoding function is extremely simple, self-contained, and easy to understand:

def _read_flo(file_name):
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    with open(file_name, "rb") as f:
        magic = np.fromfile(f, "c", count=4).tobytes()
        if magic != b"PIEH":
            raise ValueError("Magic number incorrect. Invalid .flo file")
        w = int(np.fromfile(f, "<i4", count=1))
        h = int(np.fromfile(f, "<i4", count=1))
        data = np.fromfile(f, "<f4", count=2 * w * h)
        return data.reshape(h, w, 2).transpose(2, 0, 1)
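For reference, a typical call site is a one-liner (a sketch; the path is made up, and _read_flo returns a little-endian float32 array of shape (2, H, W)):

import torch

flow = _read_flo("training/flow/alley_1/frame_0001.flo")  # hypothetical path
flow_t = torch.from_numpy(flow)  # works whenever "<f4" matches the native byte order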

I personally see no strong reason to create a new class or a new helper that will add maintenance overhead; we don't need a helper for anything. I would just copy-paste it and be done with it.

> If we go through with this design, we should revisit #4598. It should probably then be reverted for consistency.

Not necessarily. I don't think we need to worry about consistency between the use of np.frombuffer vs. torch.frombuffer. We should just use whichever tool gets the job done best, while accounting for factors like code simplicity (i.e. maintainability) and robustness.

pmeier (Collaborator, Author) commented on Nov 9, 2021

> My main concern wasn't just about using numpy vs. pytorch; it was mostly about code complexity. The original flow decoding function is extremely simple, self-contained, and easy to understand.

That is true, with two caveats:

  1. We can't use PyTorch dtypes to specify the data types. Before diving into this, I had no idea what "<f4" meant. IMO, using dtype=torch.float32, byte_order="little" is a lot more expressive and does not require you to look up documentation.

  2. Your implementation assumes that everything stays a numpy.ndarray and thus does not need to care about how the data is stored in memory. As soon as you want to convert the result into a torch.Tensor, your system also needs to be little endian. In other words: if I'm on a big-endian system, your implementation is broken without manually fixing the byte order (a fix is sketched after this list):

    >>> torch.from_numpy(np.array([0.0], dtype=np.dtype(">f4")))
    ValueError: given numpy array has byte order different from the native byte order. Conversion between byte orders is currently not supported.

    (This example uses the inverted logic, i.e. having big-endian data on a little-endian system, but it is the same the other way around.)
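A minimal sketch of the manual fix, reusing the toy data from the example above: numpy's .astype() can rewrite the array in native byte order before it is handed to PyTorch:

import numpy as np
import torch

big_endian = np.array([0.0], dtype=np.dtype(">f4"))  # big-endian float32

# .astype() copies the data and converts it to the target byte order;
# "f4" without a prefix is native-order float32, so the result is safe
# to pass to torch.from_numpy() on any system.
native = big_endian.astype(np.dtype("f4"))
tensor = torch.from_numpy(native)
print(tensor)  # tensor([0.])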

As for simplicity, my implementation is even shorter and hides all of the complexity above from the user. Plus, it is as self-contained as your version, in the sense that both implementations let another function do the heavy lifting.

> Not necessarily. I don't think we need to worry about consistency between the use of np.frombuffer vs. torch.frombuffer. We should just use whichever tool gets the job done best, while accounting for factors like code simplicity (i.e. maintainability) and robustness.

#4598 was specifically about using torch.frombuffer over np.frombuffer in a situation where endianness matters. There, the implementation became more complex in favor of a PyTorch-native approach. If we go by your sentiment, we should revert that PR.

NicolasHug (Member) left a review

Thanks @pmeier, I made some comments below.

Before merging, I think it would be nice to have some tests ensuring that the new implementation of read_flo is consistent with the current one.

I would also be curious if you could run a quick benchmark to see whether there's a significant time difference between the two.

[4 review comments on torchvision/prototype/datasets/utils/_internal.py, since resolved]
pmeier (Collaborator, Author) commented on Nov 17, 2021

> I would also be curious if you could run a quick benchmark to see whether there's a significant time difference between the two.

import contextlib
import pathlib
import time

import torch
from torchvision.datasets._optical_flow import _read_flo as read_flo_baseline
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import read_flo as read_flo_new


files = [
    file
    for file in (pathlib.Path(datasets.home()) / "sintel" / "training" / "flow").glob("**/*")
    if file.suffix == ".flo"
]


@contextlib.contextmanager
def timeit(label):
    start = time.perf_counter()
    yield
    stop = time.perf_counter()
    print(label, f"{(stop - start) / len(files) * 1e3:.1f} milliseconds per file")


with timeit("baseline"):
    for file in files:
        torch.from_numpy(read_flo_baseline(file).astype("f4", copy=False))

with timeit("baseline+copy"):
    for file in files:
        torch.from_numpy(read_flo_baseline(file).astype("f4"))

with timeit("new"):
    for file in files:
        with open(file, "r+b") as f:
            read_flo_new(f)
baseline 0.3 milliseconds per file
baseline+copy 2.2 milliseconds per file
new 2.3 milliseconds per file

So the new implementation is roughly 8x slower than the baseline, but in absolute terms that is probably insignificant within an image training loop. The memcopy is the offender: if we add the same copy to the old implementation ("baseline+copy"), the difference is negligible.
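To illustrate where the copy comes from (a minimal sketch, not the PR's actual code): the bytes read from a file are immutable, so the numpy array wrapping them is read-only, while PyTorch only supports writable tensors:

import numpy as np
import torch

raw = b"\x00\x00\x00\x00"              # bytes objects are immutable
arr = np.frombuffer(raw, dtype="<f4")  # zero-copy view, but read-only
print(arr.flags.writeable)             # False

# torch.from_numpy() warns on read-only memory, since PyTorch has no
# read-only tensors. Copying the bytes into a bytearray once yields
# mutable memory that can be wrapped without a second copy.
arr_mut = np.frombuffer(bytearray(raw), dtype="<f4")
tensor = torch.from_numpy(arr_mut)
print(tensor)  # tensor([0.])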

pmeier (Collaborator, Author) commented on Nov 18, 2021

I've added support for dropping the memcopy if the file is opened in update mode (r+b or w+b). This is even faster than the original solution. In addition, calling bytearray() on the read bytes when we are not in update mode is faster than letting numpy do the copy. With the latest commit and this addition to the benchmark script

with timeit("new+update"):
    for file in files:
        with open(file, "r+b") as f:
            read_flo_new(f)

the benchmark script now gives:

baseline        0.3 milliseconds per file
baseline+copy   2.1 milliseconds per file
new             1.3 milliseconds per file
new+update      0.1 milliseconds per file

That makes the new implementation in the regular case roughly 4x slower than the original one.
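Putting the two reading modes together, here is a minimal sketch of the dispatch described above (the function name and the exact guards are assumptions, not the PR's code):

import mmap

import numpy as np
import torch


def read_binary(file, dtype="<f4", count=-1):
    # Plain files opened in update mode ("r+b"/"w+b") can be memory-mapped
    # writable, so the tensor can wrap the file's memory without any copy.
    if getattr(file, "mode", None) in ("rb+", "wb+", "r+b", "w+b"):
        buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell():]
    else:
        # Read-only files, compressed streams, BytesIO, etc.: read the raw
        # bytes and pay for a single copy into a mutable bytearray.
        buffer = bytearray(file.read())
    # Byte-order conversion is omitted for brevity; dtype is a numpy
    # dtype string like "<f4".
    return torch.from_numpy(np.frombuffer(buffer, dtype=np.dtype(dtype), count=count))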

NicolasHug (Member) left a review

Thanks @pmeier, some last comments on the tests; looks good otherwise.

[4 review comments on test/test_prototype_datasets_utils.py]
NicolasHug (Member) left a review

Thanks a lot for your work and patience @pmeier. Approving now, since the remaining comments are fairly trivial / not very important.

@pmeier merged commit 8dcb5b8 into main on Nov 19, 2021 and deleted the flow-reader-datapipe branch.
facebook-github-bot pushed a commit that referenced this pull request on Nov 30, 2021: add prototype utilities to read arbitrary numeric binary files (#4882)

Summary:
* add FloReader datapipe

* add NumericBinaryReader

* revert unrelated change

* cleanup

* cleanup

* add comment for byte reversal

* use numpy after all

* appease mypy

* use .astype() with copy=False

* add docstring and cleanuo

* reuse current _read_flo and revert MNIST changes

* cleanup

* revert demonstration

* refactor

* cleanup

* add support for mutable memory

* add test

* add comments

* catch more exceptions

* fix mypy

* fix variable names

* hardcode flow sizes in test

* add fix dtype docstring

* expand comment on different reading modes

* add comment about files in update mode

* add tests for fromfile

* cleanup

* cleanup

Reviewed By: NicolasHug

Differential Revision: D32694313

fbshipit-source-id: 53c7c9ed32a948ad4bddc0b219c01e291835206d