diff --git a/flow/core/kernel/vehicle/aimsun.py b/flow/core/kernel/vehicle/aimsun.py index ce0d026e5..16c94558a 100644 --- a/flow/core/kernel/vehicle/aimsun.py +++ b/flow/core/kernel/vehicle/aimsun.py @@ -65,6 +65,7 @@ def __init__(self, # number of vehicles to exit the network for every time-step self._num_arrived = [] self._arrived_ids = [] + self._arrived_rl_ids = [] # contains conversion from Flow-ID to Aimsun-ID self._id_aimsun2flow = {} @@ -174,11 +175,17 @@ def update(self, reset): added_vehicles = self.kernel_api.get_entered_ids() exited_vehicles = self.kernel_api.get_exited_ids() + # keep track of arrived rl vehicles + arrived_rl_ids = [] + # add the new vehicles if they should be tracked for aimsun_id in added_vehicles: veh_type = self.kernel_api.get_vehicle_type_name(aimsun_id) if veh_type in self.tracked_vehicle_types: self._add_departed(aimsun_id) + if aimsun_id in self.get_rl_ids(): + arrived_rl_ids.append(aimsun_id) + self._arrived_rl_ids.append(arrived_rl_ids) # remove the exited vehicles if they were tracked if not reset: @@ -639,6 +646,16 @@ def get_arrived_ids(self): """See parent class.""" raise NotImplementedError + def get_arrived_rl_ids(self, k=1): + """See parent class.""" + if len(self._arrived_rl_ids) > 0: + arrived = [] + for arr in self._arrived_rl_ids[-k:]: + arrived.extend(arr) + return arrived + else: + return 0 + def get_departed_ids(self): """See parent class.""" raise NotImplementedError diff --git a/flow/core/kernel/vehicle/traci.py b/flow/core/kernel/vehicle/traci.py index 134bac49f..6f119b7bb 100644 --- a/flow/core/kernel/vehicle/traci.py +++ b/flow/core/kernel/vehicle/traci.py @@ -521,10 +521,13 @@ def get_arrived_ids(self): """See parent class.""" return self._arrived_ids - def get_arrived_rl_ids(self): + def get_arrived_rl_ids(self, k=1): """See parent class.""" if len(self._arrived_rl_ids) > 0: - return self._arrived_rl_ids[-1] + arrived = [] + for arr in self._arrived_rl_ids[-k:]: + arrived.extend(arr) + return arrived else: return 0 diff --git a/flow/envs/multiagent/base.py b/flow/envs/multiagent/base.py index ec95474c6..2d9c3cd78 100644 --- a/flow/envs/multiagent/base.py +++ b/flow/envs/multiagent/base.py @@ -122,7 +122,7 @@ def step(self, rl_actions): else: reward = self.compute_reward(rl_actions, fail=crash) - for rl_id in self.k.vehicle.get_arrived_rl_ids(): + for rl_id in self.k.vehicle.get_arrived_rl_ids(self.env_params.sims_per_step): done[rl_id] = True reward[rl_id] = 0 states[rl_id] = np.zeros(self.observation_space.shape[0])