Skip to content

Commit

Permalink
Partially fixes #1 and adds scene-centric functionality, fixing #2.
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisIvanovic committed Jul 7, 2022
1 parent 2425163 commit d1270b9
Show file tree
Hide file tree
Showing 18 changed files with 978 additions and 61 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ for t in range(1, sim_scene.scene_info.length_timesteps):
`examples/sim_example.py` contains a more comprehensive example which initializes a simulation from a scene in the nuScenes mini dataset, steps through it by replaying agents' GT motions, and computes metrics based on scene statistics (e.g., displacement error from the original GT data, velocity/acceleration/jerk histograms).

## TODO
- Merge in upstream scene batch pull request.
- Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training.
- Add more examples to the README.
- Finish README section about how to add a new dataset.
2 changes: 1 addition & 1 deletion examples/batch_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main():
future_sec=(4.8, 4.8),
only_types=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=True,
incl_robot_future=False,
incl_map=True,
map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
augmentations=[noise_hists],
Expand Down
51 changes: 51 additions & 0 deletions examples/scene_batch_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from collections import defaultdict

from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.augmentation import NoiseHistories
from trajdata.visualization.vis import plot_scene_batch


def main():
noise_hists = NoiseHistories()

dataset = UnifiedDataset(
desired_data=["nusc_mini-mini_train"],
centric="scene",
desired_dt=0.1,
history_sec=(3.2, 3.2),
future_sec=(4.8, 4.8),
only_types=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=True,
incl_map=True,
map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
augmentations=[noise_hists],
max_agent_num=20,
num_workers=4,
verbose=True,
data_dirs={ # Remember to change this to match your filesystem!
"nusc_mini": "~/datasets/nuScenes",
},
)

print(f"# Data Samples: {len(dataset):,}")

dataloader = DataLoader(
dataset,
batch_size=4,
shuffle=True,
collate_fn=dataset.get_collate_fn(),
num_workers=4,
persistent_workers=True,
)

batch: AgentBatch
for batch in tqdm(dataloader):
plot_scene_batch(batch, batch_idx=0)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = trajdata
version = 1.0.4
version = 1.0.5
author = Boris Ivanovic
author_email = bivanovic@nvidia.com
description = A unified interface to many trajectory forecasting datasets.
Expand Down
7 changes: 5 additions & 2 deletions src/trajdata/augmentation/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from trajdata.data_structures.batch import AgentBatch
from trajdata.data_structures.batch import AgentBatch, SceneBatch


class Augmentation:
Expand All @@ -14,5 +14,8 @@ def apply(self, scene_data_df: pd.DataFrame) -> None:


class BatchAugmentation(Augmentation):
def apply(self, agent_batch: AgentBatch) -> None:
def apply_agent(self, agent_batch: AgentBatch) -> None:
raise NotImplementedError()

def apply_scene(self, scene_batch: SceneBatch) -> None:
raise NotImplementedError()
9 changes: 7 additions & 2 deletions src/trajdata/augmentation/noise_histories.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import torch

from trajdata.augmentation.augmentation import BatchAugmentation
from trajdata.data_structures.batch import AgentBatch
from trajdata.data_structures.batch import AgentBatch, SceneBatch


class NoiseHistories(BatchAugmentation):
def __init__(self, mean: float = 0.0, stddev: float = 0.1) -> None:
self.mean = mean
self.stddev = stddev

def apply(self, agent_batch: AgentBatch) -> None:
def apply_agent(self, agent_batch: AgentBatch) -> None:
agent_batch.agent_hist[..., :-1, :] += torch.normal(
self.mean, self.stddev, size=agent_batch.agent_hist[..., :-1, :].shape
)
agent_batch.neigh_hist[..., :-1, :] += torch.normal(
self.mean, self.stddev, size=agent_batch.neigh_hist[..., :-1, :].shape
)

def apply_scene(self, scene_batch: SceneBatch) -> None:
scene_batch.agent_hist[..., :-1, :] += torch.normal(
self.mean, self.stddev, size=scene_batch.agent_hist[..., :-1, :].shape
)
67 changes: 66 additions & 1 deletion src/trajdata/data_structures/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AgentBatch:
maps_resolution: Optional[Tensor]
rasters_from_world_tf: Optional[Tensor]
agents_from_world_tf: Tensor
scene_ids: Optional[List]

def to(self, device) -> None:
excl_vals = {
Expand All @@ -50,6 +51,7 @@ def to(self, device) -> None:
"neigh_types",
"num_neigh",
"robot_fut_len",
"scene_ids",
}
for val in vars(self).keys():
tensor_val = getattr(self, val)
Expand Down Expand Up @@ -96,7 +98,70 @@ def for_agent_type(self, agent_type: AgentType) -> AgentBatch:
if self.rasters_from_world_tf is not None
else None,
agents_from_world_tf=self.agents_from_world_tf[match_type],
scene_ids=[
scene_id
for idx, scene_id in enumerate(self.scene_ids)
if match_type[idx]
],
)


SceneBatch = namedtuple("SceneBatch", "")
@dataclass
class SceneBatch:
data_idx: Tensor
dt: Tensor
num_agents: Tensor
agent_type: Tensor
centered_agent_state: Tensor
agent_hist: Tensor
agent_hist_extent: Tensor
agent_hist_len: Tensor
agent_fut: Tensor
agent_fut_extent: Tensor
agent_fut_len: Tensor
robot_fut: Optional[Tensor]
robot_fut_len: Optional[Tensor]
maps: Optional[Tensor]
maps_resolution: Optional[Tensor]
rasters_from_world_tf: Optional[Tensor]
centered_agent_from_world_tf: Tensor
centered_world_from_agent_tf: Tensor

def to(self, device) -> None:
for val in vars(self).keys():
tensor_val = getattr(self, val)
if tensor_val is not None:
setattr(self, val, tensor_val.to(device))

def agent_types(self) -> List[AgentType]:
unique_types: Tensor = torch.unique(self.agent_type)
return [AgentType(unique_type.item()) for unique_type in unique_types]

def for_agent_type(self, agent_type: AgentType) -> AgentBatch:
match_type = self.agent_type == agent_type
return SceneBatch(
data_idx=self.data_idx[match_type],
dt=self.dt[match_type],
num_agents=self.num_agents[match_type],
agent_type=self.agent_type[match_type],
centered_agent_state=self.centered_agent_state[match_type],
agent_hist=self.agent_hist[match_type],
agent_hist_extent=self.agent_hist_extent[match_type],
agent_hist_len=self.agent_hist_len[match_type],
agent_fut=self.agent_fut[match_type],
agent_fut_extent=self.agent_fut_extent[match_type],
agent_fut_len=self.agent_fut_len[match_type],
robot_fut=self.robot_fut[match_type]
if self.robot_fut is not None
else None,
robot_fut_len=self.robot_fut_len[match_type],
maps=self.maps[match_type] if self.maps is not None else None,
maps_resolution=self.maps_resolution[match_type]
if self.maps_resolution is not None
else None,
rasters_from_world_tf=self.rasters_from_world_tf[match_type]
if self.rasters_from_world_tf is not None
else None,
centered_agent_from_world_tf=self.centered_agent_from_world_tf[match_type],
centered_world_from_agent_tf=self.centered_world_from_agent_tf[match_type],
)
Loading

0 comments on commit d1270b9

Please sign in to comment.