Skip to content

Commit

Permalink
updated sampled dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jingGM committed Nov 16, 2024
1 parent 2066920 commit 3f27a97
Show file tree
Hide file tree
Showing 38 changed files with 78 additions and 247 deletions.
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.idea/*
*/.idea/*

build/*
*/build/*

__pycache__/*

loginfo/*
loginfo

3 changes: 0 additions & 3 deletions .idea/.gitignore

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/DTG.iml

This file was deleted.

98 changes: 0 additions & 98 deletions .idea/inspectionProfiles/Project_Default.xml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/inspectionProfiles/profiles_settings.xml

This file was deleted.

4 changes: 0 additions & 4 deletions .idea/misc.xml

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/modules.xml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/vcs.xml

This file was deleted.

Empty file modified DTG_paper.pdf
100644 → 100755
Empty file.
10 changes: 8 additions & 2 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ pip install torch_geometric pyg_lib torch_scatter torch_sparse torch_cluster tor
pip install -r requirements.txt
```

Download sample dataset in the root folder:
https://drive.google.com/drive/folders/1YClCBSCUc3_Zy3WIQfAE6_kIQ0xTOe0I?usp=sharing

# Run
- generator_type: 0: diffusion model; 1: cvae
- diffusion_model: 0: crnn; 1:unet
- crnn_type: 0: gru; 1:lstm
```
python main.py
```
python main.py --wandb_api=YOUR_WANDB_API --generator_type=0 --diffusion_model=0 --crnn_type=0
```
Empty file modified front.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified main.py
100644 → 100755
Empty file.
Empty file modified requirements.txt
100644 → 100755
Empty file.
Empty file modified src/.gitignore
100644 → 100755
Empty file.
Empty file modified src/__init__.py
100644 → 100755
Empty file.
Empty file modified src/data_loader/.gitignore
100644 → 100755
Empty file.
Empty file modified src/data_loader/__init__.py
100644 → 100755
Empty file.
Empty file modified src/data_loader/data_loader.py
100644 → 100755
Empty file.
Empty file modified src/data_loader/dataset.py
100644 → 100755
Empty file.
68 changes: 21 additions & 47 deletions src/loss.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
import time

from src.models.diff_hausdorf import HausdorffLoss
from src.utils.configs import GeneratorType, DataDict, Hausdorff, LossNames, DiffusionTypes
from src.utils.configs import GeneratorType, DataDict, Hausdorff, LossNames


class Loss(nn.Module):
def __init__(self, cfg):
super(Loss, self).__init__()

with open(join(cfg.root, "data.pkl"), "rb") as input_file:
data = pickle.load(input_file)
self.all_positions = data[DataDict.all_positions]
self.network = data[DataDict.network]
# with open(join(cfg.root, "data.pkl"), "rb") as input_file:
# data = pickle.load(input_file)
# self.network = data[DataDict.network]

self.generator_type = cfg.generator_type
self.use_traversability = cfg.use_traversability
Expand All @@ -35,7 +34,6 @@ def __init__(self, cfg):
self.distance = HausdorffLoss(mode=cfg.distance_type)

self.train_poses = cfg.train_poses
self.diffusion_type = cfg.diffusion_type
self.distance_type = cfg.distance_type
self.scale_waypoints = cfg.scale_waypoints
self.last_ratio = cfg.last_ratio
Expand All @@ -61,17 +59,15 @@ def _cropped_distance(self, path, single_map):
traversability = torch.clamp(d, 0.0001, self.collision_distance)
values = traversability[torch.where(traversability < self.collision_distance)]
if len(values) < 1:
return torch.tensor(0, device=traversability.device), torch.tensor(1, device=traversability.device)
return (torch.tensor(0, device=traversability.device, dtype=torch.float),
torch.tensor(1, device=traversability.device, dtype=torch.float))
else:
torch.cuda.empty_cache()
loss = torch.arctanh((self.collision_distance - values) / self.collision_distance)
return loss.mean(), values.mean()

def _local_collision(self, yhat, local_map):
assert len(yhat.shape) == 3, "the shape should be B,N,2"
# if len(yhat.shape) == 3:
# B, N, C = yhat.shape
# yhat = yhat.view(B, 1, N, C)
By, N, C = yhat.shape
Bl, W, H = local_map.shape
assert Bl == By, "the batch shape {} and {} should be the same".format(By, Bl)
Expand Down Expand Up @@ -120,38 +116,27 @@ def forward_cvae(self, input_dict):
return output

def forward_diffusion(self, input_dict):
noise = input_dict[DataDict.noise]
ygt = input_dict[DataDict.path]
y_hat = input_dict[DataDict.prediction]

output = {}
if self.diffusion_type == DiffusionTypes.noise:
all_loss = self.target_dis(y_hat, noise)
if self.use_traversability:
traversablility_hat = input_dict[DataDict.predict_path]
if self.train_poses:
traversability_hat_poses = traversablility_hat * self.scale_waypoints
else:
traversability_hat_poses = torch.cumsum(traversablility_hat, dim=1) * self.scale_waypoints
elif self.diffusion_type == DiffusionTypes.trajectory:
if self.train_poses:
y_hat_poses = y_hat * self.scale_waypoints
else:
y_hat_poses = torch.cumsum(y_hat, dim=1) * self.scale_waypoints
if self.use_traversability:
B, _, _ = y_hat.shape
traversability_hat_poses = y_hat_poses[int(B / 2):]
y_hat_poses = y_hat_poses[:int(B / 2)]

path_dis = self.distance(ygt, y_hat_poses).mean()
last_pose_dis = self.target_dis(ygt[:, -1, :], y_hat_poses[:, -1, :])
all_loss = self.distance_ratio * path_dis + self.last_ratio * last_pose_dis
output.update({
LossNames.last_dis: last_pose_dis,
LossNames.path_dis: path_dis,
})
if self.train_poses:
y_hat_poses = y_hat * self.scale_waypoints
else:
raise Exception("the diffusion type is not defined")
y_hat_poses = torch.cumsum(y_hat, dim=1) * self.scale_waypoints
if self.use_traversability:
B, _, _ = y_hat.shape
traversability_hat_poses = y_hat_poses[int(B / 2):]
y_hat_poses = y_hat_poses[:int(B / 2)]

path_dis = self.distance(ygt, y_hat_poses).mean()
last_pose_dis = self.target_dis(ygt[:, -1, :], y_hat_poses[:, -1, :])
all_loss = self.distance_ratio * path_dis + self.last_ratio * last_pose_dis
output.update({
LossNames.last_dis: last_pose_dis,
LossNames.path_dis: path_dis,
})

if self.use_traversability:
local_map = input_dict[DataDict.local_map]
Expand All @@ -163,22 +148,11 @@ def forward_diffusion(self, input_dict):
output.update({LossNames.loss: all_loss})
return output

def forward_estimation(self, input_dict):
gt = input_dict[DataDict.traversability_gt]
est_loss, est_dict = self.forward_traversability(loss=gt,
mu=input_dict[DataDict.traversability_mu],
var=input_dict[DataDict.traversability_var],
val=input_dict[DataDict.traversability_pred])
est_dict.update({LossNames.loss: est_loss})
return est_dict

def forward(self, input_dict):
if self.generator_type == GeneratorType.cvae:
return self.forward_cvae(input_dict=input_dict)
elif self.generator_type == GeneratorType.diffusion:
return self.forward_diffusion(input_dict=input_dict)
elif self.generator_type == GeneratorType.estimator:
return self.forward_estimation(input_dict=input_dict)

def convert_path_pixel(self, trajectory):
return np.clip(np.around(trajectory / self.map_resolution)[:, :2] + self.map_range, 0, np.inf)
Expand Down
Empty file modified src/models/.gitignore
100644 → 100755
Empty file.
Empty file modified src/models/__init__.py
100644 → 100755
Empty file.
Empty file modified src/models/backbones/.gitignore
100644 → 100755
Empty file.
Empty file modified src/models/backbones/__init__.py
100644 → 100755
Empty file.
Empty file modified src/models/backbones/rnn.py
100644 → 100755
Empty file.
Empty file modified src/models/backbones/unet.py
100644 → 100755
Empty file.
Empty file modified src/models/diff_hausdorf.py
100644 → 100755
Empty file.
Loading

0 comments on commit 3f27a97

Please sign in to comment.