-
Notifications
You must be signed in to change notification settings - Fork 6
/
traffic_simulator.py
130 lines (107 loc) · 4.24 KB
/
traffic_simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import numpy as np
from dataclasses import replace
from example_adapter import get_observation_adapter
from utils import get_vehicle_start_at_time
from smarts.core.smarts import SMARTS
from smarts.core.agent import AgentSpec
from smarts.core.agent_interface import AgentInterface
from smarts.core.controllers import ActionSpaceType
from smarts.core.scenario import Scenario
from smarts.core.traffic_history_provider import TrafficHistoryProvider
def get_action_adapter():
    """Build the adapter that converts a raw model output into an agent action.

    The returned callable requires its input to contain exactly two entries
    (enforced with ``assert``) and returns those two entries, in order, as a
    2-tuple. NOTE(review): presumably these are the two components of the
    ``ActionSpaceType.Imitation`` action used by ``TrafficSim`` -- confirm.
    """

    def action_adapter(model_action):
        assert len(model_action) == 2
        return tuple(model_action[index] for index in range(2))

    return action_adapter
class TrafficSim:
    """Single-agent SMARTS environment built on recorded traffic histories.

    On every ``reset`` one vehicle discovered in the current scenario's traffic
    history becomes the ego vehicle controlled by the agent. The episode begins
    at a time chosen in that vehicle's recorded lifetime, which is written to the
    simulator's ``TrafficHistoryProvider``. Vehicles are visited in a shuffled
    order that repeats once every vehicle has been used.

    NOTE(review): within this class only the first scenario yielded by
    ``Scenario.scenario_variations`` is used; ``_init_scenario`` is not called
    again after ``__init__``.
    """

    def __init__(self, scenarios, obs_stacked_size=1):
        """Create the simulator and load the first scenario.

        Args:
            scenarios: Scenario source handed to ``Scenario.scenario_variations``.
            obs_stacked_size: Forwarded to ``get_observation_adapter``; presumably
                the number of stacked observation frames -- confirm.
        """
        self.scenarios_iterator = Scenario.scenario_variations(scenarios, [])
        self._init_scenario()
        self.obs_stacked_size = obs_stacked_size
        self.agent_spec = AgentSpec(
            interface=AgentInterface(
                max_episode_steps=None,
                waypoints=False,
                neighborhood_vehicles=True,
                ogm=False,
                rgb=False,
                lidar=False,
                action=ActionSpaceType.Imitation,
            ),
            action_adapter=get_action_adapter(),
            observation_adapter=get_observation_adapter(obs_stacked_size),
        )
        self.smarts = SMARTS(
            agent_interfaces={},
            traffic_sim=None,
            envision=None,
        )

    def seed(self, seed):
        """Seed NumPy's global RNG, which drives start-time sampling in ``reset``.

        The vehicle order is shuffled in ``__init__`` (via ``_init_scenario``),
        i.e. before this method can be called, so it is not affected by ``seed``.
        """
        np.random.seed(seed)

    def step(self, action):
        """Apply ``action`` to the ego vehicle and advance SMARTS by one step.

        Returns:
            ``(observation, reward, {"__all__": done}, info)`` for the ego
            vehicle, where ``info`` contains ``vehicle_id`` and ``reached_goal``.
        """
        raw_observations, rewards, dones, _ = self.smarts.step(
            {self.vehicle_id: self.agent_spec.action_adapter(action)}
        )
        ego_observation = raw_observations[self.vehicle_id]
        observation = self.agent_spec.observation_adapter(ego_observation)
        return (
            observation,
            rewards[self.vehicle_id],
            {"__all__": dones[self.vehicle_id]},
            {
                "vehicle_id": self.vehicle_id,
                "reached_goal": ego_observation.events.reached_goal,
            },
        )

    def reset(self, internal_replacement=False, min_successor_time=5.0):
        """Start a new episode controlling the next vehicle in the shuffled order.

        Args:
            internal_replacement: If True, begin at a randomly sampled point of
                the vehicle's recorded lifetime instead of its first appearance.
            min_successor_time: With ``internal_replacement``, the sampled start
                leaves roughly at least this much recorded time before the
                vehicle's final exit. Offsets are drawn in 0.1 increments, in the
                same time units as mission start times (presumably seconds). A
                vehicle not alive that long starts at its first appearance.

        Returns:
            The adapted observation of the ego vehicle.

        Raises:
            ValueError: With ``internal_replacement``, if the vehicle's final exit
                time is not after its mission start time.
        """
        if self.vehicle_itr >= len(self.vehicle_ids):
            # Every vehicle has been used once; cycle through the same order again.
            self.vehicle_itr = 0
        self.vehicle_id = self.vehicle_ids[self.vehicle_itr]
        vehicle_mission = self.vehicle_missions[self.vehicle_id]

        traffic_history_provider = self.smarts.get_provider_by_type(
            TrafficHistoryProvider
        )
        assert traffic_history_provider
        traffic_history_provider.start_time = self._sample_start_time(
            self.vehicle_id, vehicle_mission, internal_replacement, min_successor_time
        )

        # The ego mission starts at t=0 of the new episode, from the start that
        # get_vehicle_start_at_time reports for the chosen history time.
        modified_mission = replace(
            vehicle_mission,
            start_time=0.0,
            start=get_vehicle_start_at_time(
                self.vehicle_id,
                traffic_history_provider.start_time,
                self.scenario.traffic_history,
            ),
        )
        self.scenario.set_ego_missions({self.vehicle_id: modified_mission})
        self.smarts.switch_ego_agents({self.vehicle_id: self.agent_spec.interface})

        raw_observations = self.smarts.reset(self.scenario)
        observation = self.agent_spec.observation_adapter(
            raw_observations[self.vehicle_id]
        )
        self.vehicle_itr += 1
        return observation

    def _sample_start_time(
        self, vehicle_id, vehicle_mission, internal_replacement, min_successor_time
    ):
        """Return the traffic-history time at which the next episode begins."""
        if not internal_replacement:
            return vehicle_mission.start_time

        end_time = self.scenario.traffic_history.vehicle_final_exit_time(vehicle_id)
        alive_time = end_time - vehicle_mission.start_time
        if alive_time <= 0:
            raise ValueError(vehicle_mission.start_time, end_time, alive_time)

        # Candidate offsets are whole 0.1 steps chosen so that roughly
        # min_successor_time of recorded history remains after the start.
        num_offsets = round(alive_time * 10) - round(min_successor_time * 10)
        if num_offsets <= 0:
            # np.random.choice(0) raises ValueError; a vehicle that is not alive
            # longer than min_successor_time simply starts at its first appearance.
            return vehicle_mission.start_time
        return vehicle_mission.start_time + np.random.choice(num_offsets) / 10

    def _init_scenario(self):
        """Load the next scenario and shuffle the vehicles found in its history."""
        self.scenario = next(self.scenarios_iterator)
        self.vehicle_missions = self.scenario.discover_missions_of_traffic_histories()
        self.vehicle_ids = list(self.vehicle_missions.keys())
        np.random.shuffle(self.vehicle_ids)
        self.vehicle_itr = 0

    def close(self):
        """Destroy the underlying SMARTS instance; calling this again is a no-op."""
        # getattr: tolerate an instance whose __init__ failed before SMARTS existed.
        smarts = getattr(self, "smarts", None)
        if smarts is not None:
            try:
                smarts.destroy()
            finally:
                # Drop the reference so a repeated close() does not destroy twice.
                self.smarts = None