Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lane error penalty #2030

Merged
merged 4 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Copy and pasting the git commit messages is __NOT__ enough.
- `TrapEntryTactic.wait_to_hijack_limit_s` field now defaults to `0`.
- `EntryTactic` derived classes now contain `condition` to provide extra filtering of candidate actors.
- `EntryTactic` derived classes now contain `start_time`.
- `RoadMap.Route` now optionally stores the start and end lanes of the route.
- `DistToDestination` metric now adds lane error penalty when agent terminates in different lane but same road as the goal position.
### Deprecated
- `visdom` is set to be removed from the SMARTS object parameters.
- Deprecated `start_time` on missions.
Expand Down
2 changes: 2 additions & 0 deletions smarts/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def create_route(self, mission: Mission, radius: Optional[float] = None):
start_lane.road, end_lane.road, via_roads, 1
)[0]
if self._route.road_length > 0:
self._route.start_lane = start_lane
self._route.end_lane = end_lane
break

if len(self._route.roads) == 0:
Expand Down
20 changes: 20 additions & 0 deletions smarts/core/road_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,26 @@ def __eq__(self, other) -> bool:
def _add_road(self, road: RoadMap.Road):
raise NotImplementedError()

@property
def start_lane(self) -> Optional[RoadMap.Lane]:
"Route's start lane."
return None

@start_lane.setter
def start_lane(self,value:RoadMap.Lane):
"Route's start lane."
raise NotImplementedError()
Gamenot marked this conversation as resolved.
Show resolved Hide resolved

@property
def end_lane(self) -> Optional[RoadMap.Lane]:
"Route's end lane."
return None

@end_lane.setter
def end_lane(self,value:RoadMap.Lane):
"Route's end lane."
raise NotImplementedError()

@property
def roads(self) -> List[RoadMap.Road]:
"""A possibly-unordered list of roads that this route covers"""
Expand Down
26 changes: 24 additions & 2 deletions smarts/core/route_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class RouteWithCache(RoadMap.Route):
def __init__(self, road_map: RoadMap):
self._map = road_map
self._logger = logging.getLogger(self.__class__.__name__)
self._start_lane: Optional[RoadMap.Lane] = None
Adaickalavan marked this conversation as resolved.
Show resolved Hide resolved
self._end_lane: Optional[RoadMap.Lane] = None

def __hash__(self) -> int:
key: int = self._cache_key # pytype: disable=annotation-type-mismatch
Expand All @@ -62,9 +64,29 @@ def __eq__(self, other) -> bool:
def _add_road(self, road: RoadMap.Road):
raise NotImplementedError()

@property
def start_lane(self) -> Optional[RoadMap.Lane]:
"Route's start lane."
return self._start_lane

@start_lane.setter
def start_lane(self,value:RoadMap.Lane):
"Route's start lane."
self._start_lane=value

@property
def end_lane(self) -> Optional[RoadMap.Lane]:
"Route's end lane."
return self._end_lane

@end_lane.setter
def end_lane(self,value:RoadMap.Lane):
"Route's end lane."
self._end_lane = value

@cached_property
def road_ids(self) -> List[str]:
"""Retruns a list of the road_ids for the Roads in this Route."""
"""Returns a list of the road_ids for the Roads in this Route."""
return [road.road_id for road in self.roads]

@classmethod
Expand Down Expand Up @@ -110,7 +132,7 @@ def remove_from_cache(self):

# TAI: could pre-cache curvatures here too (like waypoints) ?
def add_to_cache(self):
"""Add informationa about this Route to the cache if not already there."""
"""Add information about this Route to the cache if not already there."""
if self.is_cached:
return

Expand Down
71 changes: 41 additions & 30 deletions smarts/env/gymnasium/wrappers/metric/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,33 +544,44 @@ def get_dist(road_map: RoadMap, point_a: Point, point_b: Point) -> float:
float: Shortest road distance between two points in the road map.
"""

def _get_dist(start: Point, end: Point) -> float:
mission = Mission(
start=Start(
position=start.as_np_array,
heading=Heading(0),
from_front_bumper=False,
),
goal=PositionalGoal(
position=end,
radius=3,
),
)
plan = Plan(road_map=road_map, mission=mission, find_route=False)
plan.create_route(mission=mission, radius=20)
from_route_point = RoadMap.Route.RoutePoint(pt=start)
to_route_point = RoadMap.Route.RoutePoint(pt=end)

dist_tot = plan.route.distance_between(
start=from_route_point, end=to_route_point
)
if dist_tot == None:
raise CostError("Unable to find road on route near given points.")
elif dist_tot < 0:
raise CostError(
"Path from start point to end point flows in "
"the opposite direction of the generated route."
)
return dist_tot

return _get_dist(point_a, point_b)
mission = Mission(
start=Start(
position=point_a.as_np_array,
heading=Heading(0),
from_front_bumper=False,
),
goal=PositionalGoal(
position=point_b,
radius=2,
),
)
plan = Plan(road_map=road_map, mission=mission, find_route=False)
plan.create_route(mission=mission, radius=20)
assert isinstance(plan.route, RoadMap.Route)
from_route_point = RoadMap.Route.RoutePoint(pt=point_a)
to_route_point = RoadMap.Route.RoutePoint(pt=point_b)

dist_tot = plan.route.distance_between(
start=from_route_point, end=to_route_point
)
if dist_tot == None:
raise CostError("Unable to find road on route near given points.")
elif dist_tot < 0:
# This happens when agent overshoots the goal position while
# remaining outside the goal capture radius at all times. Default
# positional goal radius is 2m.
dist_tot = abs(dist_tot)

# Account for agent ending in a different lane but in the same road as
# the goal position.
start_lane = plan.route.start_lane
end_lane = plan.route.end_lane
lane_error = abs(start_lane.index-end_lane.index)
if len(plan.route.roads) == 1 and lane_error > 0:
Adaickalavan marked this conversation as resolved.
Show resolved Hide resolved
assert start_lane.road == end_lane.road
end_offset = end_lane.offset_along_lane(world_point=point_b)
lane_width, _ = end_lane.width_at_offset(end_offset)
lane_error_dist = lane_error*lane_width
dist_tot += lane_error_dist

return dist_tot