Commit a4ba565: Update costs
Gamenot committed Jan 31, 2025
1 parent 051673c
Showing 6 changed files with 103 additions and 4 deletions.
1 change: 1 addition & 0 deletions setup.cfg
@@ -50,6 +50,7 @@ argoverse =
     av2>=0.2.1
     Rtree>=0.9.7
 benchmark =
+    %(argoverse)s
     %(gymnasium)s
     %(ray)s
     %(sumo)s
12 changes: 12 additions & 0 deletions smarts/benchmark/driving_smarts/v2023/metric_formula_drive.py
@@ -132,13 +132,25 @@ def costs_to_score(costs: Costs) -> Score:
         + 0.25 * (1 - rule_violation)
     )
 
+    weighted = (
+        0.1 * (1 - dist_to_destination)
+        + 0.3 * (1 - time)
+        + 0.15 * (1 - humanness_error)
+        + 0.45 * (1 - rule_violation)
+    )
+
     return Score(
         {
             "overall": overall,
+            "weighted": np.round(weighted, 3),
             "dist_to_destination": dist_to_destination,
             "time": time,
             "humanness_error": humanness_error,
             "rule_violation": rule_violation,
+            "progress_rate": np.round(1 - dist_to_destination, 3),
+            "rule_compliance": np.round(1 - rule_violation, 3),
+            "humanness": np.round(1 - humanness_error, 3),
+            "mission_time_efficiency": np.round(1 - time, 3),
         }
     )
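A note on the new "weighted" score: it is a convex combination of the four sub-scores, with the weights (0.1 + 0.3 + 0.15 + 0.45) summing to 1.0 and rule compliance weighted most heavily, so the result stays in [0, 1] whenever each cost lies in [0, 1]. A minimal sketch of the arithmetic, using hypothetical cost values:

import numpy as np

# Hypothetical costs in [0, 1]; lower is better.
dist_to_destination, time = 0.2, 0.4
humanness_error, rule_violation = 0.1, 0.0

weighted = (
    0.1 * (1 - dist_to_destination)  # 0.080  (progress_rate = 0.8)
    + 0.3 * (1 - time)               # 0.180  (mission_time_efficiency = 0.6)
    + 0.15 * (1 - humanness_error)   # 0.135  (humanness = 0.9)
    + 0.45 * (1 - rule_violation)    # 0.450  (rule_compliance = 1.0)
)
print(np.round(weighted, 3))  # 0.845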
12 changes: 12 additions & 0 deletions smarts/benchmark/driving_smarts/v2023/metric_formula_platoon.py
@@ -134,13 +134,25 @@ def costs_to_score(costs: Costs) -> Score:
         + 0.25 * (1 - rule_violation)
     )
 
+    weighted = (
+        0.1 * (1 - dist_to_destination)
+        + 0.3 * (1 - vehicle_gap)
+        + 0.15 * (1 - humanness_error)
+        + 0.45 * (1 - rule_violation)
+    )
+
     return Score(
         {
             "overall": overall,
+            "weighted": np.round(weighted, 3),
             "dist_to_destination": dist_to_destination,
             "vehicle_gap": vehicle_gap,
             "humanness_error": humanness_error,
             "rule_violation": rule_violation,
+            "progress_rate": np.round(1 - dist_to_destination, 3),
+            "rule_compliance": np.round(1 - rule_violation, 3),
+            "humanness": np.round(1 - humanness_error, 3),
+            "safe_following_distance": np.round(1 - vehicle_gap, 3),
         }
     )
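The platoon formula mirrors the drive formula, with vehicle_gap taking the place of time and surfacing as "safe_following_distance" instead of "mission_time_efficiency". A quick sanity check one might run against either formula (the weights dict below is illustrative, not a SMARTS structure):

# Each weight vector must sum to 1.0 so "weighted" stays in [0, 1].
weights = {
    "dist_to_destination": 0.1,
    "vehicle_gap": 0.3,
    "humanness_error": 0.15,
    "rule_violation": 0.45,
}
assert abs(sum(weights.values()) - 1.0) < 1e-9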
62 changes: 58 additions & 4 deletions smarts/benchmark/entrypoints/benchmark_runner_v0.py
@@ -29,8 +29,10 @@
 import psutil
 import ray
 
+import smarts
 from smarts.benchmark.driving_smarts import load_config
 from smarts.core import config
+from smarts.core.agent_interface import ObservationFormat
 from smarts.core.utils.core_logging import suppress_output
 from smarts.core.utils.import_utils import import_module_from_file
 from smarts.env.gymnasium.wrappers.metric.formula import FormulaBase, Score
@@ -47,6 +49,25 @@ def _eval_worker(name, env_config, episodes, agent_locator, error_tolerant=False
     return _eval_worker_local(name, env_config, episodes, agent_locator, error_tolerant)
 
+
+def _get_env_attr(env, attr):
+    try:
+        return env.get_wrapper_attr(attr)
+    except AttributeError:
+        return getattr(env, attr)
+
+def _observation_format_adapt(observation_format):
+
+    if observation_format == ObservationFormat.DICT:
+        def func(env, info: Dict, observation: Dict, agent_id: str):
+            return observation[agent_id]
+    elif observation_format == ObservationFormat.SMARTS_OBS:
+        def func(env, info: Dict, observation: Dict, agent_id: str):
+            return info[agent_id]['env_obs']
+    else:
+        raise NotImplementedError(f"Observation format `{observation_format}` is not supported!")
+
+    return func
 
 def _eval_worker_local(name, env_config, episodes, agent_locator, error_tolerant=False):
     import warnings
 
@@ -57,19 +78,27 @@ def _eval_worker_local(name, env_config, episodes, agent_locator, error_tolerant
         agent_interface=agent_registry.make(locator=agent_locator).interface,
         **env_config["kwargs"],
     )
 
     env = Metrics(env, formula_path=env_config["metric_formula"])
     agents = {
         agent_id: agent_registry.make_agent(locator=agent_locator)[0]
-        for agent_id in env.agent_ids
+        for agent_id in _get_env_attr(env, "agent_ids")
     }
+    agent_observation_format_adaptors = {
+        agent_id: _observation_format_adapt(agent_registry.make(locator=agent_locator).interface.observation_format)
+        for agent_id in _get_env_attr(env, "agent_ids")
+    }
+
+
     obs, info = env.reset()
     current_resets = 0
     try:
         while current_resets < episodes:
             try:
                 action = {
-                    agent_id: agents[agent_id].act(agent_obs)
+                    agent_id: agents[agent_id].act(
+                        agent_observation_format_adaptors[agent_id](env, info, obs, agent_id)
+                    )
                     for agent_id, agent_obs in obs.items()
                 }
                 # assert env.action_space.contains(action)
@@ -88,6 +117,29 @@
     finally:
         records = env.records()
         env.close()
+    # try:
+    #     while current_resets < episodes:
+    #         try:
+    #             action = {
+    #                 agent_id: agents[agent_id].act(agent_obs['env_obs'])
+    #                 for agent_id, agent_obs in info.items()
+    #             }
+    #             # assert env.action_space.contains(action)
+    #         except ArithmeticError:
+    #             logging.error("Policy robustness failed.")
+    #             # # TODO MTA: mark policy failures
+    #             # env.mark_policy_failure()
+    #             if not error_tolerant:
+    #                 raise
+    #             terminated, truncated = False, True
+    #         else:
+    #             obs, reward, terminated, truncated, info = env.step(action)
+    #             if terminated["__all__"] or truncated["__all__"]:
+    #                 current_resets += 1
+    #                 obs, info = env.reset()
+    # finally:
+    #     records = env.records()
+    #     env.close()
     return name, records
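Two of the additions above deserve a note. _get_env_attr smooths over newer Gymnasium releases, where wrappers no longer implicitly forward arbitrary attributes and instead expose env.get_wrapper_attr(name); the except branch keeps older environment stacks working. _observation_format_adapt is a small factory: it returns a per-agent callable with a uniform signature, so the action loop never branches on the agent's declared format. A minimal standalone sketch of the same pattern (the Format enum and sample data are illustrative, not the SMARTS types):

from enum import Enum, auto

class Format(Enum):  # stand-in for ObservationFormat
    DICT = auto()
    SMARTS_OBS = auto()

def make_adaptor(fmt):
    # Each adaptor takes (env, info, obs, agent_id) and returns the
    # observation shape the agent declared it wants.
    if fmt == Format.DICT:
        return lambda env, info, obs, agent_id: obs[agent_id]
    if fmt == Format.SMARTS_OBS:
        return lambda env, info, obs, agent_id: info[agent_id]["env_obs"]
    raise NotImplementedError(f"Observation format `{fmt}` is not supported!")

adapt = make_adaptor(Format.DICT)
print(adapt(None, {}, {"agent_0": [1.0, 2.0]}, "agent_0"))  # [1.0, 2.0]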


@@ -178,11 +230,13 @@ def benchmark(benchmark_args, agent_locator) -> Tuple[Dict, Dict]:
     print(message)
 
     debug = benchmark_args.get("debug", {})
+    # debug = {"serial": True}
     iterator = _serial_task_iterator if debug.get("serial") else _parallel_task_iterator
 
-    root_dir = Path(__file__).resolve().parents[3]
+    smarts_dir = Path(smarts.__path__[0]).resolve()
+    root_dir = smarts_dir.parent
     metric_formula_default = (
-        root_dir / "smarts" / "env" / "gymnasium" / "wrappers" / "metric" / "formula.py"
+        smarts_dir / "env" / "gymnasium" / "wrappers" / "metric" / "formula.py"
     )
     weighted_scores, agent_scores = {}, {}
     for env_name, env_config in benchmark_args["envs"].items():
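The path change in benchmark() anchors the default metric formula to wherever the smarts package is actually installed, rather than climbing parent directories of __file__, so it works for a pip-installed package as well as a source checkout. A short sketch of the difference (printed path is illustrative):

import smarts
from pathlib import Path

# Before: assumed the repository layout <root>/smarts/benchmark/entrypoints/...
# root_dir = Path(__file__).resolve().parents[3]

# After: anchored to the package itself.
smarts_dir = Path(smarts.__path__[0]).resolve()
formula = smarts_dir / "env" / "gymnasium" / "wrappers" / "metric" / "formula.py"
print(formula)  # e.g. .../site-packages/smarts/env/gymnasium/wrappers/metric/formula.py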
15 changes: 15 additions & 0 deletions smarts/core/agent_interface.py
@@ -387,6 +387,16 @@ def actors_alive(self):
         warnings.warn("Use interest.", category=DeprecationWarning)
         return self.interest
 
+class ObservationFormat(Enum):
+    DICT = auto()
+    """This agent uses dictionary formatted SMARTS observations, likely with gym environments."""
+    SMARTS_OBS = auto()
+    """This agent uses SMARTS observations from the high level engine interface."""
+    CUSTOM_DICT = auto()
+    """This agent uses a modified set of dictionary formatted observations and is not fully wrapped."""
+    # PRIVILEGED = auto()
+    # """This agent uses privileged SMARTS engine calls."""
+
 
 @dataclass
 class AgentInterface:
@@ -481,6 +491,11 @@ class AgentInterface:
     """Add custom renderer outputs.
     """
 
+    observation_format: ObservationFormat = ObservationFormat.DICT
+    """
+    The observation format that this agent wants to use.
+    """
+
     def __post_init__(self):
         self.neighborhood_vehicle_states = AgentInterface._resolve_config(
             self.neighborhood_vehicle_states, NeighborhoodVehicles
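With the new field, an agent interface declares up front which observation shape its policy's act() expects; DICT remains the default, so existing interfaces are unchanged. A minimal usage sketch, mirroring the zoo policy entry points below:

from smarts.core.agent_interface import AgentInterface, ObservationFormat
from smarts.core.controllers import ActionSpaceType

# An agent that wants raw SMARTS Observation objects rather than dict obs.
interface = AgentInterface(
    action=ActionSpaceType.TargetPose,
    observation_format=ObservationFormat.SMARTS_OBS,
)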
5 changes: 5 additions & 0 deletions zoo/policies/__init__.py
@@ -5,6 +5,7 @@
     AgentInterface,
     AgentType,
     DoneCriteria,
+    ObservationFormat,
     Waypoints,
 )
 from smarts.core.controllers import ActionSpaceType
@@ -162,6 +163,7 @@ def entry_point_iamp(**kwargs):
     return AgentSpec(
         interface=AgentInterface(
             action=ActionSpaceType.TargetPose,
+            observation_format=ObservationFormat.SMARTS_OBS,
         ),
         agent_builder=lib.Policy,
     )
@@ -180,6 +182,7 @@ def entry_point_casl(**kwargs):
     return AgentSpec(
         interface=AgentInterface(
             action=ActionSpaceType.TargetPose,
+            observation_format=ObservationFormat.SMARTS_OBS,
         ),
         agent_builder=lib.Policy,
     )
@@ -199,6 +202,8 @@ def entry_point_dsac(**kwargs):
     return AgentSpec(
         interface=AgentInterface(
             action=ActionSpaceType.TargetPose,
+            lidar_point_cloud=True,
+            observation_format=ObservationFormat.SMARTS_OBS,
         ),
         agent_builder=lib.Policy,
     )
