Skip to content

Commit

Permalink
Merge pull request #41 from Visual-Behavior/1-update-samples
Browse files Browse the repository at this point in the history
1 update samples
  • Loading branch information
thibo73800 authored Sep 9, 2021
2 parents 04e8881 + 8446b05 commit c64dae3
Show file tree
Hide file tree
Showing 13 changed files with 40 additions and 24 deletions.
18 changes: 10 additions & 8 deletions alodataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
import aloscene

DATASETS_DOWNLOAD_PATHS = {
"coco": "https://storage.googleapis.com/visualbehavior-sample/coco_sample.pkl",
"waymo": "https://storage.googleapis.com/visualbehavior-sample/waymo_sample.pkl",
"mot17": "https://storage.googleapis.com/visualbehavior-sample/mot17_sample.pkl",
"chairsSDHom": "https://storage.googleapis.com/visualbehavior-sample/chairsSDHom_sample.pkl",
"crowdhuman": "https://storage.googleapis.com/visualbehavior-sample/crowdhuman_sample.pkl",
"flyingChairs2": "https://storage.googleapis.com/visualbehavior-sample/flyingChairs2_sample.pkl",
"flyingThings": "https://storage.googleapis.com/visualbehavior-sample/flyingThings_sample.pkl",
"Sintel": "https://storage.googleapis.com/visualbehavior-sample/Sintel_sample.pkl",
"coco": "https://storage.googleapis.com/visualbehavior-sample/coco.pkl",
"waymo": "https://storage.googleapis.com/visualbehavior-sample/waymo.pkl",
"mot17": "https://storage.googleapis.com/visualbehavior-sample/mot17.pkl",
"chairsSDHom": "https://storage.googleapis.com/visualbehavior-sample/chairsSDHom.pkl",
"crowdhuman": "https://storage.googleapis.com/visualbehavior-sample/crowdhuman.pkl",
"FlyingChairs2": "https://storage.googleapis.com/visualbehavior-sample/FlyingChairs2.pkl",
"FlyingThings3DSubset": "https://storage.googleapis.com/visualbehavior-sample/FlyingThings3DSubset.pkl",
"SintelDisparity": "https://storage.googleapis.com/visualbehavior-sample/SintelDisparity.pkl",
"SintelFlow": "https://storage.googleapis.com/visualbehavior-sample/SintelFlow.pkl",
"SintelMulti": "https://storage.googleapis.com/visualbehavior-sample/SintelMulti.pkl",
}


Expand Down
2 changes: 2 additions & 0 deletions alodataset/chairssdhom_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def getitem(self, idx):
"""
Return frames corresponding to a sequence
"""
if self.sample:
return BaseDataset.__getitem__(self, idx)
sequence_data = self.items[idx]
return self.get_frames(sequence_data)

Expand Down
4 changes: 3 additions & 1 deletion alodataset/crowd_human_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def load_gt(self, dict_input):
return bboxes, classes

def getitem(self, idx):
if self.sample:
return BaseDataset.__getitem__(self, idx)

record = self.items[idx]

image_id = record["ID"]
Expand Down Expand Up @@ -256,7 +259,6 @@ def main():
crowd_human_dataset = CrowdHumanDataset(sample=True)

crowd_human_dataset.prepare()

for i, frames in enumerate(crowd_human_dataset.train_loader(batch_size=2, sampler=None, num_workers=0)):
frames = Frame.batch_list(frames)
frames.get_view().render(figsize=(20, 10))
Expand Down
4 changes: 4 additions & 0 deletions alodataset/flying_chairs2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class FlyingChairs2Dataset(BaseDataset, SplitMixin):

def __init__(self, **kwargs):
super(FlyingChairs2Dataset, self).__init__(name="FlyingChairs2", **kwargs)
if self.sample:
return
self.dir_path = self._dir_path()
self.items = self._get_sequences()

Expand Down Expand Up @@ -73,6 +75,8 @@ def getitem(self, idx):
idx : int
index of the sequence
"""
if self.sample:
return BaseDataset.__getitem__(self, idx)
sequence_data = self.items[idx]
return self._get_frames(sequence_data)

Expand Down
2 changes: 2 additions & 0 deletions alodataset/flyingthings3D_subset_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def get_frames(self, sequence_data):
return frames

def getitem(self, idx):
if self.sample:
return BaseDataset.__getitem__(self, idx)
sequence_data = self.items[idx]
return self.get_frames(sequence_data)

Expand Down
4 changes: 2 additions & 2 deletions alodataset/merge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def train_loader(self, batch_size=1, num_workers=2, sampler=torch.utils.data.Ran
if __name__ == "__main__":
from alodataset import ChairsSDHomDataset, FlyingThings3DSubsetDataset, Split

chairs = ChairsSDHomDataset(split=Split.VAL)
flying = FlyingThings3DSubsetDataset(split=Split.VAL, sequence_size=2, transform_fn=lambda f: f["left"])
chairs = ChairsSDHomDataset(sample=True)
flying = FlyingThings3DSubsetDataset(sample=True, transform_fn=lambda f: f["left"])

multi = MergeDataset([chairs, flying])

Expand Down
8 changes: 4 additions & 4 deletions alodataset/mot17.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def _add_line(self, sequence: str, line: str, config: configparser.ConfigParser)
)

def getitem(self, idx):
if self.sample:
return BaseDataset.__getitem__(self, idx)

seq = list(self.items[idx]["seq"])
sequence_name = self.items[idx]["mot_sequence"]

Expand Down Expand Up @@ -219,10 +222,7 @@ def getitem(self, idx):

def main():
"""Main"""
mot_dataset = Mot17(
split=Split.TRAIN, validation_sequences=["MOT17-05"], detections_set=["FRCNN"], sequence_size=2, random_step=30
)

mot_dataset = Mot17(sample=True)
for frames in mot_dataset.stream_loader():
frames.names
frames.get_view(
Expand Down
5 changes: 5 additions & 0 deletions alodataset/sintel_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(self, cameras=None, labels=None, passes=None, sintel_sequences=None
self.labels = labels if labels is not None else self.LABELS
self.passes = passes if passes is not None else self.PASSES
self.sintel_sequences = sintel_sequences if sintel_sequences is not None else self.SINTEL_SEQUENCES
if self.sample:
return

self._assert_inputs()

self.items = self._get_sequences()
Expand Down Expand Up @@ -154,5 +157,7 @@ def _get_frames(self, sequence_data):
return frames

def getitem(self, idx):
if self.sample:
return BaseDataset.__getitem__(self, idx)
sequence_data = self.items[idx]
return self._get_frames(sequence_data)
4 changes: 2 additions & 2 deletions alodataset/sintel_disparity_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def _get_camera_frames(self, sequence_data, camera):


if __name__ == "__main__":
dataset = SintelDisparityDataset(sequence_size=2)
dataset = SintelDisparityDataset(sample=True)
# show some frames at various indices
for idx in [1, 15, 64]:
for idx in [1, 2, 5]:
frames = dataset.getitem(idx)["left"]
frames.get_view().render()
4 changes: 2 additions & 2 deletions alodataset/sintel_flow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def _get_camera_frames(self, sequence_data, camera):


if __name__ == "__main__":
dataset = SintelFlowDataset(sequence_size=2)
dataset = SintelFlowDataset(sample=True)
# show some frames at various indices
for idx in [1, 15, 64]:
for idx in [1, 2, 5]:
frames = dataset.getitem(idx)["left"]
frames.get_view().render()
4 changes: 2 additions & 2 deletions alodataset/sintel_multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def set_dataset_dir(self, dataset_dir: str):


if __name__ == "__main__":
dataset = SintelMultiDataset(sequence_size=2)
dataset = SintelMultiDataset(sample=True)
# show some frames at various indices
for idx in [1, 15, 64]:
for idx in [1, 2, 5]:
frames = dataset.getitem(idx)
frames["left"].get_view().render()
frames["right"].get_view().render()
1 change: 0 additions & 1 deletion alodataset/waymo_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def main():
"""Main"""
waymo_dataset = WaymoDataset(sample=True)
waymo_dataset.prepare()
print(waymo_dataset.dataset_dir)

for frames in waymo_dataset.train_loader(batch_size=2):
frames = Frame.batch_list([frame["front"] for frame in frames])
Expand Down
4 changes: 2 additions & 2 deletions unittest/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def test_deformable_detr(mock_args):


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**raft_args))
def _test_raft(mock_args):
def test_raft(mock_args):
raft_train_on_chairs()


if __name__ == "__main__":
test_deformable_detr()
test_detr()
# test_raft()
test_raft()

0 comments on commit c64dae3

Please sign in to comment.