Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sintel Dataset to the dataset prototype API #4895

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
eff55d3
WIP: Sintel Dataset
krshrimali Nov 8, 2021
61831dc
Failing to read streamwrapper object in Python
krshrimali Nov 8, 2021
12b5915
KeyZipper updates
krshrimali Nov 9, 2021
d2dcba9
Merge remote-tracking branch 'upstream/main' into dataset/sintel
krshrimali Nov 9, 2021
1b690ac
seek of closed file error for now
krshrimali Nov 9, 2021
6f371c7
Working...
krshrimali Nov 9, 2021
081c70f
Rearranging functions
krshrimali Nov 9, 2021
ad74f96
Merge remote-tracking branch 'upstream/main' into dataset/sintel
krshrimali Nov 10, 2021
6c106e7
Fix mypy failures, minor edits
krshrimali Nov 10, 2021
32cc661
Apply suggestions from code review
krshrimali Nov 11, 2021
7f27e3f
Address reviews...
krshrimali Nov 11, 2021
28def28
Merge branch 'main' into dataset/sintel
krshrimali Nov 11, 2021
cdcb914
Update torchvision/prototype/datasets/_builtin/sintel.py
krshrimali Nov 12, 2021
b58c14b
Add support for 'both' as pass_name
krshrimali Nov 12, 2021
1d7a36e
Merge branch 'dataset/sintel' of github.com:krshrimali/vision into da…
krshrimali Nov 12, 2021
52ba6da
Keep imports in the same block
krshrimali Nov 12, 2021
e515fbb
Convert re.search output to bool
krshrimali Nov 12, 2021
7892eb6
Merge branch 'main' into dataset/sintel
krshrimali Nov 12, 2021
ee3c78f
Address reviews, cleanup, one more todo left...
krshrimali Nov 15, 2021
79c65fb
Merge branch 'dataset/sintel' of github.com:krshrimali/vision into da…
krshrimali Nov 15, 2021
08cd984
Merge branch 'main' into dataset/sintel
krshrimali Nov 15, 2021
591633a
little endian format for data (flow file)
krshrimali Nov 15, 2021
98872fd
Merge branch 'dataset/sintel' of github.com:krshrimali/vision into da…
krshrimali Nov 15, 2021
7ccca53
Merge branch 'main' into dataset/sintel
krshrimali Nov 15, 2021
8f84b51
As per review, use frombuffer consistently
krshrimali Nov 15, 2021
709263c
Merge branch 'dataset/sintel' of github.com:krshrimali/vision into da…
krshrimali Nov 15, 2021
6b40366
Only filter pass name, and not png, include flow filter there
krshrimali Nov 16, 2021
34e8de3
Rename the func
krshrimali Nov 16, 2021
cb904c5
Add label (scene dir), needs review
krshrimali Nov 16, 2021
0e13b3f
Merge branch 'main' into dataset/sintel
krshrimali Nov 16, 2021
10bdc4b
Add test for sintel dataset
krshrimali Nov 17, 2021
7b4265f
Merge branch 'dataset/sintel' of github.com:krshrimali/vision into da…
krshrimali Nov 17, 2021
d34ebe6
Merge branch 'main' into dataset/sintel
krshrimali Nov 17, 2021
54618c6
Remove comment
krshrimali Nov 17, 2021
6c04d5f
Temporary fix + test class fixes
krshrimali Nov 19, 2021
84c4e88
Revert temp fix
krshrimali Nov 19, 2021
ebf7e4a
Merge branch 'main' into dataset/sintel
pmeier Nov 19, 2021
c0b254c
use common read_flo instead of custom implementation
pmeier Nov 19, 2021
e9fa656
remove more obsolete code
pmeier Nov 19, 2021
3724869
[DEBUG] check if tests also run on Python 3.9
pmeier Nov 19, 2021
69194e1
Revert "[DEBUG] check if tests also run on Python 3.9"
pmeier Nov 19, 2021
b4cce90
store bytes to avoid reading twice from file handle
pmeier Nov 22, 2021
527d1fa
Merge branch 'main' into dataset/sintel
pmeier Nov 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
from .semeion import SEMEION
from .sintel import SINTEL
from .voc import VOC
182 changes: 182 additions & 0 deletions torchvision/prototype/datasets/_builtin/sintel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import io
import pathlib
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Iterable, TypeVar

import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Demultiplexer,
Mapper,
Shuffler,
Filter,
IterKeyZipper,
ZipArchiveReader,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE

T = TypeVar("T")

try:
    # Python 3.10+ ships pairwise in itertools.
    from itertools import pairwise  # type: ignore[attr-defined]
except ImportError:
    from itertools import tee

    def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
        """Yield consecutive overlapping pairs: s -> (s0, s1), (s1, s2), ..."""
        first, second = tee(iterable)
        next(second, None)
        return zip(first, second)


class InSceneGrouper(IterDataPipe[Tuple[Tuple[str, T], Tuple[str, T]]]):
    """Pair up consecutive (path, data) entries that share a parent directory.

    The input datapipe is sorted first, so after sorting, neighbors whose paths
    have the same parent belong to the same scene and are yielded as a pair.
    """

    def __init__(self, datapipe: IterDataPipe[Tuple[str, T]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[Tuple[str, Any], Tuple[str, Any]]]:
        for first, second in pairwise(sorted(self.datapipe)):
            same_scene = pathlib.Path(first[0]).parent == pathlib.Path(second[0]).parent
            if same_scene:
                yield first, second


class SINTEL(Dataset):

    # Matches Sintel file names such as "frame_0001.png", "image_0001.png" or
    # "frame_0001.flo"; the named group "idx" captures the frame number.
    _FILE_NAME_PATTERN = re.compile(r"(frame|image)_(?P<idx>\d+)[.](flo|png)")

def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"sintel",
type=DatasetType.IMAGE,
homepage="http://sintel.is.tue.mpg.de/",
valid_options=dict(
split=("train", "test"),
pass_name=("clean", "final", "both"),
),
)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = HttpResource(
"http://files.is.tue.mpg.de/sintel/MPI-Sintel-complete.zip",
sha256="bdc80abbe6ae13f96f6aa02e04d98a251c017c025408066a00204cd2c7104c5f",
)
return [archive]

def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
path = pathlib.Path(data[0])
# The dataset contains has the folder "training", while allowed options for `split` are
# "train" and "test", we don't check for equality here ("train" != "training") and instead
# check if split is in the folder name
return split in path.parents[2].name

def _filter_images(self, data: Tuple[str, Any], *, pass_name: str) -> bool:
path = pathlib.Path(data[0])
if pass_name == "both":
matched = path.parents[1].name in ["clean", "final"]
else:
matched = path.parents[1].name == pass_name
return matched and path.suffix == ".png"

def _classify_archive(self, data: Tuple[str, Any], *, pass_name: str) -> Optional[int]:
path = pathlib.Path(data[0])
suffix = path.suffix
if suffix == ".flo":
return 0
elif suffix == ".png":
return 1
else:
return None

def _read_flo(self, file: io.IOBase) -> torch.Tensor:
magic = file.read(4)
if magic != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")
w = int.from_bytes(file.read(4), "little")
h = int.from_bytes(file.read(4), "little")

data = file.read(2 * w * h * 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are expected to be encoded as little endian as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a TODO left from the current comments. Working on it now. Thanks @NicolasHug for pointing this out.

Copy link
Contributor Author

@krshrimali krshrimali Nov 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the dtype from np.float32 to <f4 for data_arr, but just wanted to quickly check, is this the correct fix, or is there any other way to do this? cc: @pmeier

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes <f4 would be correct, it's a little endian float that fits in 4 bytes (float32).

But it might be preferable to always call the same function instead of mixing frombuffer and from_bytes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that makes sense. Thanks, @NicolasHug! I've fixed this (preferred using frombuffer in both instances).

data_arr = np.frombuffer(data, dtype="<f4")

# Creating a copy of the underlying array, to avoid UserWarning: "The given NumPy array
# is not writeable, and PyTorch does not support non-writeable tensors."
return torch.from_numpy(np.copy(data_arr.reshape(h, w, 2).transpose(2, 0, 1)))

def _flows_key(self, data: Tuple[str, Any]) -> Tuple[str, int]:
path = pathlib.Path(data[0])
category = path.parent.name
idx = int(self._FILE_NAME_PATTERN.match(path.name).group("idx")) # type: ignore[union-attr]
return category, idx

def _add_fake_flow_data(self, data: Tuple[str, Any]) -> Tuple[tuple, Tuple[str, Any]]:
return ((None, None), data)

def _images_key(self, data: Tuple[Tuple[str, Any], Tuple[str, Any]]) -> Tuple[str, int]:
return self._flows_key(data[0])

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, io.IOBase], Any],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
config: DatasetConfig,
) -> Dict[str, Any]:
flo, images = data
img1, img2 = images
flow_arr = self._read_flo(flo[1]) if flo[1] else None

path1, buffer1 = img1
path2, buffer2 = img2

return dict(
image1=decoder(buffer1) if decoder else buffer1,
image1_path=path1,
image2=decoder(buffer2) if decoder else buffer2,
image2_path=path2,
flow=flow_arr,
flow_path=flo[0],
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
archive_dp = ZipArchiveReader(dp)

curr_split = Filter(archive_dp, self._filter_split, fn_kwargs=dict(split=config.split))
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
filtered_curr_split = Filter(curr_split, self._filter_images, fn_kwargs=dict(pass_name=config.pass_name))
if config.split == "train":
flo_dp, pass_images_dp = Demultiplexer(
filtered_curr_split,
2,
partial(self._classify_archive, pass_name=config.pass_name),
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
flo_dp = Shuffler(flo_dp, buffer_size=INFINITE_BUFFER_SIZE)
pass_images_dp: IterDataPipe[Tuple[str, Any], Tuple[stry, Any]] = InSceneGrouper(pass_images_dp)
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
zipped_dp = IterKeyZipper(
flo_dp,
pass_images_dp,
key_fn=self._flows_key,
ref_key_fn=self._images_key,
)
else:
pass_images_dp = Shuffler(filtered_curr_split, buffer_size=INFINITE_BUFFER_SIZE)
pass_images_dp = InSceneGrouper(pass_images_dp)
zipped_dp = Mapper(pass_images_dp, self._add_fake_flow_data)

return Mapper(zipped_dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder, config=config))