Skip to content

Commit 532f870

Browse files
authored
Merge pull request #1 from mit-wu-lab/env_demo
add env demo
2 parents d04a312 + 6a8549e commit 532f870

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

code/env/environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, config: EnvContext):
5757
super().__init__()
5858
self.config: IntersectionZooEnvConfig = config["intersectionzoo_env_config"]
5959
self.task_context: TaskContext | None = self.config.task_context
60-
self.prefix: str = str(config.worker_index)
60+
self.prefix: str = str(config.worker_index) if hasattr(config, "worker_index") else ""
6161
self.traci = None
6262
self.traffic_state: Optional[TrafficState] = None
6363
self._curr_step = 0

code/env_demo.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import numpy as np
5+
from env.config import IntersectionZooEnvConfig
6+
from env.task_context import PathTaskContext
7+
from env.environment import IntersectionZooEnv
8+
from sumo.constants import REGULAR
9+
10+
parser = argparse.ArgumentParser(description='Demo run arguments')
11+
parser.add_argument('--dir', default='/Users/bfreydt/MIT_local/IntersectionZoo/wd/new_exp', type=str, help='Result directory')
12+
parser.add_argument('--intersection_dir', default='/Users/bfreydt/MIT_local/IntersectionZoo/dataset/salt-lake-city', type=str, help='Path to intersection dataset')
13+
parser.add_argument('--penetration', default=0.33, type=str, help='Eco drive adoption rate')
14+
parser.add_argument('--temperature_humidity', default='68_46', type=str, help='Temperature and humidity for evaluations')
15+
16+
args = parser.parse_args()
17+
print(args)
18+
19+
tasks = PathTaskContext(
20+
dir=Path(args.intersection_dir),
21+
single_approach=True,
22+
penetration_rate=args.penetration,
23+
temperature_humidity=args.temperature_humidity,
24+
electric_or_regular=REGULAR,
25+
)
26+
27+
env_conf = IntersectionZooEnvConfig(
28+
task_context=tasks.sample_task(),
29+
working_dir=Path(args.dir),
30+
moves_emissions_models=[args.temperature_humidity],
31+
fleet_reward_ratio=1,
32+
)
33+
34+
# Create the environment
35+
env = IntersectionZooEnv({"intersectionzoo_env_config": env_conf})
36+
37+
def filter_obs(obs: dict):
38+
def simplify(v):
39+
if isinstance(v, np.ndarray):
40+
if len(v) == 1:
41+
return v[0]
42+
else:
43+
return v.tolist()
44+
else:
45+
return v
46+
47+
return {k: {
48+
k2: simplify(v2) for k2, v2 in v.items() if k2 in ["speed", "relative_distance", "tl_phase"]
49+
} for k,v in obs.items() if k != "mock"}
50+
51+
def filter_rew(rew: dict):
52+
return {k: v for k,v in rew.items() if k != "mock"}
53+
54+
# Reset the environment
55+
obs, _ = env.reset()
56+
terminated = {"__all__": False}
57+
while not terminated["__all__"]:
58+
# Send a constant action for all agents
59+
action = {agent: [1] for agent in obs.keys()}
60+
61+
# Take a step in the environment
62+
obs, reward, terminated, truncated, info = env.step(action)
63+
64+
# Print the observations and reward
65+
print("Observations:", filter_obs(obs))
66+
print("Reward:", filter_rew(reward))
67+
input("Press Enter to continue...")
68+
69+
# Close the environment
70+
env.close()

0 commit comments

Comments
 (0)