Commit devkit-v0.4

mh0797 committed Apr 3, 2024
1 parent 8674f92 commit 1d25630

Showing 28 changed files with 19,101 additions and 663 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -48,6 +48,11 @@


## Changelog <a name="changelog"></a>
- **`[2024/04/03]`** NAVSIM v0.4 release
  - Support for the test-phase frames of the competition
  - Download script for the trainval split
  - EgoStatus MLP agent and training pipeline
  - Refactoring, fixes, and documentation
- **`[2024/03/25]`** NAVSIM v0.3 release (official devkit version for warm-up phase)
  - Renames the environment variable NUPLAN_EXP_ROOT to NAVSIM_EXP_ROOT
  - Adds code for leaderboard submission
31 changes: 31 additions & 0 deletions docs/agents.md
@@ -30,6 +30,37 @@ Let’s dig deeper into this class. It has to implement the following methods:
Details on the output format can be found below.

**The future trajectory has to be returned as an object of type `navsim.common.dataclasses.Trajectory`. For examples, see the constant velocity agent or the human agent.**
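As an illustrative sketch (not devkit code), a `compute_trajectory()` returning such an object could look as follows. The constant-velocity logic and the assumption that `ego_velocity[0]` approximates the forward speed are ours; `Trajectory` and `AgentInput` are the devkit classes:

```python
import numpy as np

from navsim.common.dataclasses import AgentInput, Trajectory


def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
    """Method of your agent class: drive straight at the current speed."""
    num_poses, interval = 8, 0.5  # 4s horizon at 0.5s steps (devkit default)
    # Assumption: the first component of ego_velocity is the forward speed.
    speed = agent_input.ego_statuses[-1].ego_velocity[0]
    poses = np.zeros((num_poses, 3), dtype=np.float32)  # rows are (x, y, heading)
    poses[:, 0] = speed * interval * np.arange(1, num_poses + 1)
    return Trajectory(poses)
```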

# Learning-based Agents
Most likely, your agent will involve learning-based components.
NAVSIM provides a lightweight and easy-to-use interface for training.
To use it, your agent has to implement some further functionality.
In addition to the methods mentioned above, you have to implement the methods below.
Have a look at `navsim.agents.ego_status_mlp_agent.EgoStatusMLPAgent` for a complete example; a condensed sketch also follows this list.

- `get_feature_builders()`
  Has to return a list of feature builders (of type `navsim.planning.training.abstract_feature_target_builder.AbstractFeatureBuilder`).
  Feature builders take the `AgentInput` object and compute the feature tensors used for agent training and inference. One feature builder can compute multiple feature tensors. They have to be returned in a dictionary, which is then provided to the model in the forward pass.
  Currently, we provide the following feature builders:
  - `EgoStatusFeatureBuilder` (returns a tensor containing the current velocity, acceleration, and driving command)
  - _the list will be extended in future devkit versions_

- `get_target_builders()`
  Similar to `get_feature_builders()`, returns the target builders (of type `navsim.planning.training.abstract_feature_target_builder.AbstractTargetBuilder`) used in training. In contrast to feature builders, they have access to the `Scene` object, which contains ground-truth information (instead of just the `AgentInput`).

- `forward()`
  The forward pass through the model. Features are provided as a dictionary containing all the features generated by the feature builders. All tensors are already batched and on the same device as the model. The forward pass has to output a dictionary in which one entry has the key "trajectory" and contains a tensor representing the future trajectory, i.e., of shape [B, T, 3], where B is the batch size, T is the number of future timesteps, and 3 refers to (x, y, heading).

- `compute_loss()`
  Given the features, the targets, and the model predictions, this function computes the loss used for training. The loss has to be returned as a single tensor.

- `get_optimizers()`
  Use this function to define the optimizers used for training.
  Depending on whether you want to use a learning-rate scheduler, this function needs to return either just an optimizer (of type `torch.optim.Optimizer`) or a dictionary containing the optimizer (key: "optimizer") and a learning-rate scheduler of type `torch.optim.lr_scheduler.LRScheduler` (key: "lr_scheduler").

- `compute_trajectory()`
  In contrast to non-learning-based agents, you don't have to implement this function.
  During inference, the trajectory is computed automatically using the feature builders and the forward method.
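
To make the interface concrete, here is a condensed, illustrative sketch of the training-related methods. It is not the devkit's reference implementation: the `_mlp` module, the `"ego_state"` feature key, the L1 loss, and all hyperparameter values are assumptions; see `EgoStatusMLPAgent` for the real example.

```python
from typing import Dict, Union

import torch

from navsim.agents.abstract_agent import AbstractAgent


class SketchAgent(AbstractAgent):
    """Illustrative fragment; name(), initialize(), get_sensor_config() etc. omitted."""

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Assumes a feature builder produced an "ego_state" tensor.
        poses = self._mlp(features["ego_state"])
        # Reshape to [B, T, 3]: batch, future timesteps, (x, y, heading).
        return {"trajectory": poses.reshape(-1, 8, 3)}

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        # Any single scalar tensor works; L1 on the poses is a simple choice.
        return torch.nn.functional.l1_loss(predictions["trajectory"], targets["trajectory"])

    def get_optimizers(self) -> Union[torch.optim.Optimizer, Dict]:
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # Without a scheduler, `return optimizer` alone would also be valid.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
```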
## Inputs

`get_sensor_config()` can be overridden to determine which sensors are accessible to the agent.
9 changes: 8 additions & 1 deletion docs/install.md
@@ -19,10 +19,13 @@ Navigate to the download directory and download the maps
cd download && ./download_maps
```

Next, download the splits you want to use.
You can download the mini, trainval, test, and competition_test splits with the following scripts:
```
./download_mini.sh
./download_trainval.sh
./download_test.sh
./download_competition_test.sh
```

**The mini split and the test split take ~160GB and ~220GB of disk space, respectively.**
@@ -36,9 +39,13 @@ This will download the splits into the download directory. From there, move it into your dataset folder, resulting in the following structure:
```
├── maps
├── navsim_logs
│   ├── test
│   ├── trainval
│   ├── competition_test
│   └── mini
└── sensor_blobs
    ├── test
    ├── trainval
    ├── competition_test
    └── mini
```
Set the required environment variables by adding the following to your `~/.bashrc` file
14 changes: 9 additions & 5 deletions docs/submission.md
@@ -4,14 +4,18 @@ NAVSIM comes with official leaderboards on HuggingFace. The leaderboards prevent

To submit to a leaderboard, you need to create a pickle file that contains a trajectory for each test scenario. NAVSIM provides a script to create such a pickle file.

Have a look at `run_create_submission_pickle.sh`: this file creates the pickle file for the ConstantVelocity agent. You can run it for your own agent by replacing the `agent` override.
Follow the [submission instructions on HuggingFace](https://huggingface.co/spaces/AGC2024-P/e2e-driving-2024) to upload your submission.

**Note that you have to set the variables `TEAM_NAME`, `AUTHORS`, `EMAIL`, `INSTITUTION`, and `COUNTRY` in `run_create_submission_pickle.sh` to generate a valid submission file.**

### Warm-up track
The warm-up track evaluates your submission on a [warm-up leaderboard](https://huggingface.co/spaces/AGC2024-P/e2e-driving-warmup) based on the `mini` split. This allows you to test your method and get familiar with the devkit and the submission procedure, with a less restrictive submission budget (up to 5 submissions daily). Instructions on making a submission on HuggingFace are available in the HuggingFace space. Performance on the warm-up leaderboard is not taken into consideration for determining your team's ranking in the 2024 Autonomous Grand Challenge.
Use the script `run_create_submission_pickle_warmup.sh`, which already contains the overrides `scene_filter=warmup_test` and `split=mini`, to generate the submission file for the warm-up track.

You should be able to obtain the same evaluation results as on the server by running the evaluation locally.
To do so, use the override `scene_filter=warmup_test` when executing the script that runs the PDM scoring (e.g., `run_cv_pdm_score_evaluation.sh` for the constant-velocity agent).

### Formal track
This is the [official challenge leaderboard](https://huggingface.co/spaces/AGC2024-P/e2e-driving-2024), based on secret held-out test frames (see the competition_test split on the install page).
Use the script `run_create_submission_pickle.sh`; by default, it runs with `scene_filter=competition_test` and `split=competition_test`.
You only need to set your own agent with the `agent` override.
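
For a quick local sanity check of the generated pickle, a small sketch (hypothetical: the output file name and the exact structure inside depend on the script and devkit version; the docs above only guarantee one trajectory per test scenario):

```python
import pickle

# Hypothetical path; the script reports where it wrote the submission file.
with open("submission.pkl", "rb") as f:
    submission = pickle.load(f)

# Inspect the top-level object to verify your agent produced entries
# for all test scenarios before uploading.
print(type(submission))
```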
9 changes: 9 additions & 0 deletions download/download_competition_test.sh
@@ -0,0 +1,9 @@
# Download and extract the metadata (logs) for the competition_test split
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_metadata_private_test_e2e.tgz
tar -xzf openscene_metadata_private_test_e2e.tgz
rm openscene_metadata_private_test_e2e.tgz
mv competition_test competition_test_navsim_logs

# Download and extract the sensor data for the competition_test split
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_private_test_e2e.tgz
tar -xzf openscene_sensor_private_test_e2e.tgz
rm openscene_sensor_private_test_e2e.tgz
mv competition_test competition_test_sensor_blobs
21 changes: 21 additions & 0 deletions download/download_trainval.sh
@@ -0,0 +1,21 @@
# Download and extract the trainval metadata (logs)
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_metadata_trainval.tgz
tar -xzf openscene_metadata_trainval.tgz
rm openscene_metadata_trainval.tgz

# The trainval camera data is sharded into 143 archives; fetch and extract each one
for split in {0..142}; do
    wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_trainval_camera/openscene_sensor_trainval_camera_${split}.tgz
    echo "Extracting file openscene_sensor_trainval_camera_${split}.tgz"
    tar -xzf openscene_sensor_trainval_camera_${split}.tgz
    rm openscene_sensor_trainval_camera_${split}.tgz
done

# Same for the 143 lidar archives
for split in {0..142}; do
    wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_trainval_lidar/openscene_sensor_trainval_lidar_${split}.tgz
    echo "Extracting file openscene_sensor_trainval_lidar_${split}.tgz"
    tar -xzf openscene_sensor_trainval_lidar_${split}.tgz
    rm openscene_sensor_trainval_lidar_${split}.tgz
done

# Rename to the directory layout expected by NAVSIM
mv openscene-v1.1/meta_datas trainval_navsim_logs
mv openscene-v1.1/sensor_blobs trainval_sensor_blobs
rm -r openscene-v1.1
98 changes: 73 additions & 25 deletions navsim/agents/abstract_agent.py
@@ -1,29 +1,18 @@
from __future__ import annotations

from abc import abstractmethod, ABC
from typing import Dict, Union, List

import torch

from navsim.common.dataclasses import AgentInput, Trajectory, SensorConfig
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder


class AbstractAgent(torch.nn.Module, ABC):
    """
    Interface for a generic end-to-end agent.
    """

    def __init__(
        self,
        requires_scene: bool = False,
    ):
        super().__init__()
        self.requires_scene = requires_scene

    @abstractmethod
    def name(self) -> str:
@@ -39,19 +28,78 @@ def get_sensor_config(self) -> SensorConfig:
        """
        pass

    @abstractmethod
    def initialize(self) -> None:
        """
        Initialize the agent (e.g., load model weights).
        """
        pass

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the agent.
        :param features: Dictionary of features.
        :return: Dictionary of predictions.
        """
        raise NotImplementedError

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """
        :return: List of feature builders.
        """
        raise NotImplementedError("No feature builders. Agent does not support training.")

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """
        :return: List of target builders.
        """
        raise NotImplementedError("No target builders. Agent does not support training.")

    def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
        """
        Computes the ego vehicle trajectory.
        :param agent_input: Dataclass with agent inputs.
        :return: Trajectory representing the predicted future ego positions.
        """
        features: Dict[str, torch.Tensor] = {}
        # build features from the raw agent input
        for builder in self.get_feature_builders():
            features.update(builder.compute_features(agent_input))

        # add batch dimension
        features = {k: v.unsqueeze(0) for k, v in features.items()}

        # forward pass without gradient tracking
        with torch.no_grad():
            predictions = self.forward(features)
            poses = predictions["trajectory"].squeeze(0).numpy()

        # wrap the predicted poses in a Trajectory object
        return Trajectory(poses)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Computes the loss used for backpropagation based on the features, targets and model predictions.
        """
        raise NotImplementedError("No loss. Agent does not support training.")

    def get_optimizers(
        self,
    ) -> Union[
        torch.optim.Optimizer,
        Dict[str, Union[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]],
    ]:
        """
        Returns the optimizers that are used by the PyTorch Lightning trainer.
        Has to be either a single optimizer or a dict of optimizer and lr scheduler.
        """
        raise NotImplementedError("No optimizers. Agent does not support training.")
102 changes: 102 additions & 0 deletions navsim/agents/ego_status_mlp_agent.py
@@ -0,0 +1,102 @@
from typing import Any, Dict, List, Optional

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling

from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataclasses import AgentInput, Scene, SensorConfig
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder


class EgoStatusFeatureBuilder(AbstractFeatureBuilder):
    def __init__(self):
        pass

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        # Concatenate the current velocity, acceleration, and driving command
        # into a single flat feature vector.
        ego_status = agent_input.ego_statuses[-1]
        velocity = torch.tensor(ego_status.ego_velocity)
        acceleration = torch.tensor(ego_status.ego_acceleration)
        driving_command = torch.tensor(ego_status.driving_command)
        ego_state_feature = torch.cat([velocity, acceleration, driving_command], dim=-1)

        return {"ego_state": ego_state_feature}


class TrajectoryTargetBuilder(AbstractTargetBuilder):
    def __init__(self, trajectory_sampling: TrajectorySampling):
        self._trajectory_sampling = trajectory_sampling

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        # The ground-truth future trajectory is only available via the Scene object.
        future_trajectory = scene.get_future_trajectory(
            num_trajectory_frames=self._trajectory_sampling.num_poses
        )
        return {"trajectory": torch.tensor(future_trajectory.poses)}


class EgoStatusMLPAgent(AbstractAgent):
    def __init__(
        self,
        trajectory_sampling: TrajectorySampling,
        hidden_layer_dim: int,
        lr: float,
        checkpoint_path: Optional[str] = None,
    ):
        super().__init__()
        self._trajectory_sampling = trajectory_sampling
        self._checkpoint_path = checkpoint_path

        self._lr = lr

        # Input dimension 8 matches the concatenated ego-status feature above.
        self._mlp = torch.nn.Sequential(
            torch.nn.Linear(8, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, self._trajectory_sampling.num_poses * 3),
        )

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_no_sensors()

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        return [
            TrajectoryTargetBuilder(trajectory_sampling=self._trajectory_sampling),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        return [EgoStatusFeatureBuilder()]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        poses: torch.Tensor = self._mlp(features["ego_state"])
        return {"trajectory": poses.reshape(-1, self._trajectory_sampling.num_poses, 3)}

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        return torch.nn.functional.l1_loss(predictions["trajectory"], targets["trajectory"])

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        return torch.optim.Adam(self._mlp.parameters(), lr=self._lr)
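
For context, a hypothetical usage sketch of this agent. The hyperparameter values are illustrative assumptions, while the `TrajectorySampling` arguments match the devkit's default trajectory format (4s at 0.5s intervals):

```python
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling

from navsim.agents.ego_status_mlp_agent import EgoStatusMLPAgent

agent = EgoStatusMLPAgent(
    trajectory_sampling=TrajectorySampling(time_horizon=4, interval_length=0.5),
    hidden_layer_dim=512,  # assumed value
    lr=1e-4,               # assumed value
    checkpoint_path=None,  # set a checkpoint path and call initialize() for evaluation
)
```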
2 changes: 1 addition & 1 deletion navsim/common/dataclasses.py
@@ -212,7 +212,7 @@ def __post_init__(self):
class Trajectory:
    poses: npt.NDArray[np.float32]  # local coordinates
    trajectory_sampling: TrajectorySampling = TrajectorySampling(
        time_horizon=4, interval_length=0.5  # changed from time_horizon=5
    )

    def __post_init__(self):