From e5206d78eff5067ef9a7376f9dc7345965a116fa Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Mon, 17 Apr 2023 19:31:49 +0000 Subject: [PATCH 01/59] Add conditions to sstudio --- CHANGELOG.md | 3 + smarts/sstudio/types.py | 158 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c5d91cb78..296bd5f072 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ Copy and pasting the git commit messages is __NOT__ enough. - Documented the challenge objective, desired inference code structure, and use of baseline example, for Driving SMARTS 2023.1 (i.e., basic motion planning) and 2023.2 (i.e, turns) benchmarks. - Added an env wrapper for constraining the relative target pose action range. - Added a specialised metric formula module for Driving SMARTS 2023.1 and 2023.2 benchmark. +- Added representation interface `Condition` and `ConditionState` for conditions to scenario studio. ### Changed - The trap manager, `TrapManager`, is now a subclass of `ActorCaptureManager`. - Considering lane-change time ranges between 3s and 6s, assuming a speed of 13.89m/s, the via sensor lane acquisition range was increased from 40m to 80m, for better driving ability. @@ -83,6 +84,8 @@ Copy and pasting the git commit messages is __NOT__ enough. - Driving SMARTS 2023.3 benchmark and the metrics module now uses `actor_of_interest_re_filter` from scenario metadata to identify the lead vehicle. - Included `RelativeTargetPose` action space to the set of allowed action spaces in `platoon-v0` env. - `Collision.collidee_id` now gives the vehicle id rather than the name of the owner of the vehicle (usually the agent id.) `Collision.collidee_owner_id` now provides the id of the controlling `agent` (or other controlling entity in the future.) This is because 1) `collidee_id` should refer to the body and 2) in most cases the owner name would be `None`. +- `sstudio` generated scenario vehicle traffic IDs are now shortened. +- Entry tactics now use conditions to determine when they should capture an actor. ### Deprecated ### Fixed - Fixed issues related to waypoints in junctions on Argoverse maps. Waypoints will now be generated for all paths leading through the lane(s) the vehicle is on. diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 3095ff0658..8db9cb671f 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -18,11 +18,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. import collections.abc as collections_abc +import enum import logging import math import random from dataclasses import dataclass, field, replace -from enum import IntEnum +from enum import IntEnum, IntFlag from sys import maxsize from typing import ( Any, @@ -595,11 +596,164 @@ class Traffic: """ +class ConditionState(IntFlag): + """Represents the state of a condition.""" + + FALSE = 0 + """This condition is false.""" + UNTRIGGERED = 1 + """This condition is false and never evaluated true before.""" + EXPIRED = 2 + """This condition is false and will never evaluate true.""" + TRUE = 4 + """This condition is true.""" + + def __bool__(self) -> bool: + return self.TRUE in self + + +@dataclass(frozen=True) +class Condition: + """This describes a case that must be true.""" + + def evaluate(self, *args, **kwargs) -> ConditionState: + """Used to evaluate if a condition is met. + + Returns: + ConditionState: The evaluation result of the condition. + """ + raise NotImplementedError() + + def negate(self) -> "Condition": + """Negates this condition.""" + return NegatedCondition(self) + + def conjoin(self, other: "Condition"): + """AND's this condition with the other condition.""" + return CompoundCondition( + self, other, operator=ConditionLogicalOperator.CONJUNCTION + ) + + def disjoin(self, other: "Condition"): + """OR's this condition with the other condition.""" + return CompoundCondition( + self, other, operator=ConditionLogicalOperator.DISJUNCTION + ) + + def implicate(self, other: "Condition"): + """Current condition must be false or both conditions true to be true.""" + return CompoundCondition( + self, other, operator=ConditionLogicalOperator.IMPLICATION + ) + + +@dataclass(frozen=True) +class LiteralCondition(Condition): + """This condition evaluates as a literal without considering evaluation parameters.""" + + literal: ConditionState + """The literal value of this condition.""" + + def evaluate(self, *args, **kwargs) -> ConditionState: + return self.literal + + def negate(self) -> "LiteralCondition": + return LiteralCondition(~self.literal) + + +@dataclass(frozen=True) +class TimeWindowCondition(Condition): + """This condition should be true in the given simulation time window.""" + + start: float + """The starting simulation time before which this condition becomes false.""" + end: float + """The ending simulation time as of which this condition becomes expired.""" + + def evaluate(self, *args, simulation_time, **kwargs): + if self.start <= simulation_time < self.end: + return ConditionState.TRUE + elif self.end >= simulation_time: + return ConditionState.EXPIRED + return ConditionState.UNTRIGGERED + + +@dataclass(frozen=True) +class DependeeActorCondition(Condition): + """This condition should be true if the given actor exists.""" + + actor_id: str + """The id of an actor in the simulation that needs to exist for this condition to be true.""" + + def evaluate(self, *args, actor_ids, **kwargs): + if self.actor_id in actor_ids: + return ConditionState.TRUE + return ConditionState.FALSE + + +class ConditionLogicalOperator(IntEnum): + """Represents logical operators between conditions.""" + + CONJUNCTION = enum.auto() + """Evaluate true if both operands are true, otherwise false.""" + + DISJUNCTION = enum.auto() + """Evaluate true if either operand is true, otherwise false.""" + + IMPLICATION = enum.auto() + """Evaluate true if either the first operand is false, or both operands are true, otherwise false.""" + + ## This would be desirable but makes the implementation more difficult in comparison to a negated condition. + # NEGATION=enum.auto() + # """True if its operand is false, otherwise false.""" + + +@dataclass(frozen=True) +class NegatedCondition(Condition): + """This condition negates the inner condition.""" + + inner_condition: Condition + """The inner condition to negate.""" + + def evaluate(self, *args, **kwargs) -> ConditionState: + return ~self.inner_condition.evaluate(*args, **kwargs) + + +@dataclass(frozen=True) +class CompoundCondition: + """This condition should be true if the given actor exists.""" + + first_condition: Condition + """The first condition.""" + + second_condition: Condition + """The second condition.""" + + operator: ConditionLogicalOperator + """The operator used to combine these conditions.""" + + def evaluate(self, *args, actor_ids, **kwargs): + eval_0 = self.first_condition.evaluate(*args, **kwargs) + if self.operator == ConditionLogicalOperator.IMPLICATION and not eval_0: + return ConditionState.TRUE + + eval_1 = self.second_condition.evaluate(*args, **kwargs) + if self.operator == ConditionLogicalOperator.IMPLICATION and eval_0 and eval_1: + return ConditionState.TRUE + + if self.operator == ConditionLogicalOperator.CONJUNCTION: + return eval_0 & eval_1 + elif self.operator == ConditionLogicalOperator.DISJUNCTION: + return eval_0 | eval_1 + + return ConditionState.FALSE + + @dataclass(frozen=True) class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" - pass + conditions: Tuple[Condition, ...] @dataclass(frozen=True) From ba6169dd5919eeda2f3824a718ab66f870de72a3 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Mon, 17 Apr 2023 19:48:03 +0000 Subject: [PATCH 02/59] Update entry tactic to use a single condition. --- smarts/sstudio/types.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 8db9cb671f..81e1039fed 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -624,6 +624,9 @@ def evaluate(self, *args, **kwargs) -> ConditionState: """ raise NotImplementedError() + +@dataclass(frozen=True) +class LogicalCondition(Condition): def negate(self) -> "Condition": """Negates this condition.""" return NegatedCondition(self) @@ -648,7 +651,7 @@ def implicate(self, other: "Condition"): @dataclass(frozen=True) -class LiteralCondition(Condition): +class LiteralCondition(LogicalCondition): """This condition evaluates as a literal without considering evaluation parameters.""" literal: ConditionState @@ -662,7 +665,7 @@ def negate(self) -> "LiteralCondition": @dataclass(frozen=True) -class TimeWindowCondition(Condition): +class TimeWindowCondition(LogicalCondition): """This condition should be true in the given simulation time window.""" start: float @@ -679,7 +682,7 @@ def evaluate(self, *args, simulation_time, **kwargs): @dataclass(frozen=True) -class DependeeActorCondition(Condition): +class DependeeActorCondition(LogicalCondition): """This condition should be true if the given actor exists.""" actor_id: str @@ -709,7 +712,7 @@ class ConditionLogicalOperator(IntEnum): @dataclass(frozen=True) -class NegatedCondition(Condition): +class NegatedCondition(LogicalCondition): """This condition negates the inner condition.""" inner_condition: Condition @@ -720,7 +723,7 @@ def evaluate(self, *args, **kwargs) -> ConditionState: @dataclass(frozen=True) -class CompoundCondition: +class CompoundCondition(LogicalCondition): """This condition should be true if the given actor exists.""" first_condition: Condition @@ -732,7 +735,7 @@ class CompoundCondition: operator: ConditionLogicalOperator """The operator used to combine these conditions.""" - def evaluate(self, *args, actor_ids, **kwargs): + def evaluate(self, *args, **kwargs): eval_0 = self.first_condition.evaluate(*args, **kwargs) if self.operator == ConditionLogicalOperator.IMPLICATION and not eval_0: return ConditionState.TRUE @@ -753,7 +756,8 @@ def evaluate(self, *args, actor_ids, **kwargs): class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" - conditions: Tuple[Condition, ...] + condition: LogicalCondition + """The condition to determine if this entry tactic should be used.""" @dataclass(frozen=True) From d9e57e8ee42299b2a489a49be0f9b2042169fb82 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Mon, 17 Apr 2023 20:02:15 +0000 Subject: [PATCH 03/59] Clean up naming. --- smarts/sstudio/types.py | 81 ++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 81e1039fed..4d92677a42 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -612,9 +612,26 @@ def __bool__(self) -> bool: return self.TRUE in self +class ConditionOperator(IntEnum): + """Represents logical operators between conditions.""" + + CONJUNCTION = enum.auto() + """Evaluate true if both operands are true, otherwise false.""" + + DISJUNCTION = enum.auto() + """Evaluate true if either operand is true, otherwise false.""" + + IMPLICATION = enum.auto() + """Evaluate true if either the first operand is false, or both operands are true, otherwise false.""" + + ## This would be desirable but makes the implementation more difficult in comparison to a negated condition. + # NEGATION=enum.auto() + # """True if its operand is false, otherwise false.""" + + @dataclass(frozen=True) class Condition: - """This describes a case that must be true.""" + """This encompasses an expression to evaluate to a logical result.""" def evaluate(self, *args, **kwargs) -> ConditionState: """Used to evaluate if a condition is met. @@ -624,34 +641,25 @@ def evaluate(self, *args, **kwargs) -> ConditionState: """ raise NotImplementedError() - -@dataclass(frozen=True) -class LogicalCondition(Condition): - def negate(self) -> "Condition": + def negate(self) -> "NegatedCondition": """Negates this condition.""" return NegatedCondition(self) - def conjoin(self, other: "Condition"): + def conjoin(self, other: "Condition") -> "CompoundCondition": """AND's this condition with the other condition.""" - return CompoundCondition( - self, other, operator=ConditionLogicalOperator.CONJUNCTION - ) + return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) - def disjoin(self, other: "Condition"): + def disjoin(self, other: "Condition") -> "CompoundCondition": """OR's this condition with the other condition.""" - return CompoundCondition( - self, other, operator=ConditionLogicalOperator.DISJUNCTION - ) + return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) - def implicate(self, other: "Condition"): + def implicate(self, other: "Condition") -> "CompoundCondition": """Current condition must be false or both conditions true to be true.""" - return CompoundCondition( - self, other, operator=ConditionLogicalOperator.IMPLICATION - ) + return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) @dataclass(frozen=True) -class LiteralCondition(LogicalCondition): +class LiteralCondition(Condition): """This condition evaluates as a literal without considering evaluation parameters.""" literal: ConditionState @@ -665,7 +673,7 @@ def negate(self) -> "LiteralCondition": @dataclass(frozen=True) -class TimeWindowCondition(LogicalCondition): +class TimeWindowCondition(Condition): """This condition should be true in the given simulation time window.""" start: float @@ -682,7 +690,7 @@ def evaluate(self, *args, simulation_time, **kwargs): @dataclass(frozen=True) -class DependeeActorCondition(LogicalCondition): +class DependeeActorCondition(Condition): """This condition should be true if the given actor exists.""" actor_id: str @@ -694,25 +702,8 @@ def evaluate(self, *args, actor_ids, **kwargs): return ConditionState.FALSE -class ConditionLogicalOperator(IntEnum): - """Represents logical operators between conditions.""" - - CONJUNCTION = enum.auto() - """Evaluate true if both operands are true, otherwise false.""" - - DISJUNCTION = enum.auto() - """Evaluate true if either operand is true, otherwise false.""" - - IMPLICATION = enum.auto() - """Evaluate true if either the first operand is false, or both operands are true, otherwise false.""" - - ## This would be desirable but makes the implementation more difficult in comparison to a negated condition. - # NEGATION=enum.auto() - # """True if its operand is false, otherwise false.""" - - @dataclass(frozen=True) -class NegatedCondition(LogicalCondition): +class NegatedCondition(Condition): """This condition negates the inner condition.""" inner_condition: Condition @@ -723,7 +714,7 @@ def evaluate(self, *args, **kwargs) -> ConditionState: @dataclass(frozen=True) -class CompoundCondition(LogicalCondition): +class CompoundCondition(Condition): """This condition should be true if the given actor exists.""" first_condition: Condition @@ -732,21 +723,21 @@ class CompoundCondition(LogicalCondition): second_condition: Condition """The second condition.""" - operator: ConditionLogicalOperator + operator: ConditionOperator """The operator used to combine these conditions.""" def evaluate(self, *args, **kwargs): eval_0 = self.first_condition.evaluate(*args, **kwargs) - if self.operator == ConditionLogicalOperator.IMPLICATION and not eval_0: + if self.operator == ConditionOperator.IMPLICATION and not eval_0: return ConditionState.TRUE eval_1 = self.second_condition.evaluate(*args, **kwargs) - if self.operator == ConditionLogicalOperator.IMPLICATION and eval_0 and eval_1: + if self.operator == ConditionOperator.IMPLICATION and eval_0 and eval_1: return ConditionState.TRUE - if self.operator == ConditionLogicalOperator.CONJUNCTION: + if self.operator == ConditionOperator.CONJUNCTION: return eval_0 & eval_1 - elif self.operator == ConditionLogicalOperator.DISJUNCTION: + elif self.operator == ConditionOperator.DISJUNCTION: return eval_0 | eval_1 return ConditionState.FALSE @@ -756,7 +747,7 @@ def evaluate(self, *args, **kwargs): class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" - condition: LogicalCondition + condition: Condition """The condition to determine if this entry tactic should be used.""" From 6d7e449bc13520ce16fd676e2c59493b83a0928e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 28 Apr 2023 20:56:22 +0000 Subject: [PATCH 04/59] Add delay condition and subject conditions. --- smarts/sstudio/types.py | 74 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 4d92677a42..10673c72f7 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -58,6 +58,7 @@ from smarts.core.utils.file import pickle_hash_int from smarts.core.utils.id import SocialAgentId from smarts.core.utils.math import rotate_cw_around_point +from smarts.sstudio.types import ConditionState class _SUMO_PARAMS_MODE(IntEnum): @@ -601,7 +602,7 @@ class ConditionState(IntFlag): FALSE = 0 """This condition is false.""" - UNTRIGGERED = 1 + BEFORE = 1 """This condition is false and never evaluated true before.""" EXPIRED = 2 """This condition is false and will never evaluate true.""" @@ -658,6 +659,21 @@ def implicate(self, other: "Condition") -> "CompoundCondition": return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) +@dataclass(frozen=True) +class SubjectCondition(Condition): + """This condition assumes that there is a subject involved.""" + + def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: + """Used to evaluate if a condition is met. + + Args: + actor_info: Information about the currently relevant actor. + Returns: + ConditionState: The evaluation result of the condition. + """ + raise NotImplementedError() + + @dataclass(frozen=True) class LiteralCondition(Condition): """This condition evaluates as a literal without considering evaluation parameters.""" @@ -686,7 +702,7 @@ def evaluate(self, *args, simulation_time, **kwargs): return ConditionState.TRUE elif self.end >= simulation_time: return ConditionState.EXPIRED - return ConditionState.UNTRIGGERED + return ConditionState.BEFORE @dataclass(frozen=True) @@ -713,6 +729,60 @@ def evaluate(self, *args, **kwargs) -> ConditionState: return ~self.inner_condition.evaluate(*args, **kwargs) +@dataclass(frozen=True) +class DelayCondition(Condition): + """This condition delays the inner condition by a number of seconds. + + This can be used to wait for some time after the inner condition has become true to be true. + Note that the original condition may no longer be true by the time delay has expired. + """ + + inner_condition: Condition + """The inner condition to delay.""" + + delay_seconds: float + """The number of seconds to delay for.""" + + inner_affects_final_result: bool = False + """If the inner condition must still be true at the end of the delay to be true.""" + + def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: + key = "met_time" + if (met_time := getattr(self, key, None)) is not None: + if simulation_time > met_time + self.delay_seconds: + result = ConditionState.TRUE + if self.inner_affects_final_result: + result &= self.inner_condition.evaluate( + *args, simulation_time, **kwargs + ) + return result + elif self.inner_condition.evaluate(*args, simulation_time, **kwargs): + setattr(self, key, simulation_time) + return ConditionState.FALSE + + +@dataclass(frozen=True) +class OnRoadCondition(SubjectCondition): + """This condition is true if the subject is on road.""" + + def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: + return ConditionState.TRUE if actor_info.on_road else ConditionState.FALSE + + +@dataclass(frozen=True) +class IsVehicleType(SubjectCondition): + """This condition is true if the subject is of the given types.""" + + vehicle_type: str + + def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: + return ( + ConditionState.TRUE + if actor_info.vehicle_type == self.vehicle_type + else ConditionState.FALSE + ) + + @dataclass(frozen=True) class CompoundCondition(Condition): """This condition should be true if the given actor exists.""" From 1166e9854222289ccf7fec121232dfdb84ee5432 Mon Sep 17 00:00:00 2001 From: Tucker Date: Wed, 3 May 2023 10:14:08 -0400 Subject: [PATCH 05/59] Remove circular import. --- smarts/sstudio/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 10673c72f7..799858b71d 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -58,7 +58,6 @@ from smarts.core.utils.file import pickle_hash_int from smarts.core.utils.id import SocialAgentId from smarts.core.utils.math import rotate_cw_around_point -from smarts.sstudio.types import ConditionState class _SUMO_PARAMS_MODE(IntEnum): From 2b47c50564163dac53daae2bfecbedf39d3bb6f8 Mon Sep 17 00:00:00 2001 From: Tucker Date: Wed, 3 May 2023 17:26:56 -0400 Subject: [PATCH 06/59] Remove condition from entry_tactic. --- smarts/sstudio/types.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 799858b71d..26a6b824c6 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -812,12 +812,13 @@ def evaluate(self, *args, **kwargs): return ConditionState.FALSE -@dataclass(frozen=True) +@dataclass(frozen=True, ) class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" - condition: Condition - """The condition to determine if this entry tactic should be used.""" + pass + # condition: Condition + # """The condition to determine if this entry tactic should be used.""" @dataclass(frozen=True) From 8780ea13a1e2f50f5a0a58b3fa73bbc61c0afb01 Mon Sep 17 00:00:00 2001 From: Tucker Date: Thu, 4 May 2023 10:50:31 -0400 Subject: [PATCH 07/59] Prepare for evaluating conditions --- smarts/sstudio/types.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 26a6b824c6..04b4eb5c69 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -769,7 +769,7 @@ def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: @dataclass(frozen=True) -class IsVehicleType(SubjectCondition): +class VehicleTypeCondition(SubjectCondition): """This condition is true if the subject is of the given types.""" vehicle_type: str @@ -812,7 +812,7 @@ def evaluate(self, *args, **kwargs): return ConditionState.FALSE -@dataclass(frozen=True, ) +@dataclass(frozen=True) class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" @@ -834,6 +834,8 @@ class TrapEntryTactic(EntryTactic): default_entry_speed: Optional[float] = None """The speed that the vehicle starts at when the hijack limit expiry emits a new vehicle""" + condition: Condition = LiteralCondition(ConditionState.FALSE) + @dataclass(frozen=True) class IdEntryTactic(EntryTactic): @@ -845,6 +847,8 @@ class IdEntryTactic(EntryTactic): patience: float = 0.1 """Defines the amount of time this tactic will wait for an actor.""" + condition: Condition = LiteralCondition(ConditionState.FALSE) + def __post_init__(self): assert isinstance(self.actor_id, str) assert isinstance(self.patience, (float, int)) From ec8720259330ecdec3cc78a77075bf6a55595fe3 Mon Sep 17 00:00:00 2001 From: Tucker Date: Thu, 4 May 2023 13:11:04 -0400 Subject: [PATCH 08/59] Add conditions to id_actor_capture_manager. --- smarts/core/id_actor_capture_manager.py | 12 ++++++++++-- smarts/sstudio/types.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 0b307d16ea..55f832e1e6 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -58,8 +58,9 @@ def step(self, sim): for a, (b, m) in self._actor_for_agent.items() if m.start_time <= sim.elapsed_sim_time and a in social_vehicle_ids ): - assert isinstance(mission.entry_tactic, IdEntryTactic) - patience_expiry = mission.start_time + mission.entry_tactic.patience + entry_tactic = mission.entry_tactic + assert isinstance(entry_tactic, IdEntryTactic) + patience_expiry = mission.start_time + entry_tactic.patience if sim.elapsed_sim_time > patience_expiry: self._log.error( f"Actor aquisition skipped for `{agent_id}` scheduled to start between " @@ -69,6 +70,13 @@ def step(self, sim): used_actors.append(actor_id) sim.agent_manager.teardown_ego_agents({agent_id}) continue + vehicle = sim.vehicle_index.vehicle_by_id(actor_id) + if not entry_tactic.condition.evaluate( + simulation_time = sim.elapsed_sim_time, + actor_ids = sim.vehicle_index.vehicle_ids, + vehicle_state = vehicle.state if vehicle else None, + ): + continue vehicle: Optional[Vehicle] = self._take_existing_vehicle( sim, actor_id, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 04b4eb5c69..6ed9326e8f 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -774,10 +774,10 @@ class VehicleTypeCondition(SubjectCondition): vehicle_type: str - def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: + def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: return ( ConditionState.TRUE - if actor_info.vehicle_type == self.vehicle_type + if vehicle_state.vehicle_config_type == self.vehicle_type else ConditionState.FALSE ) From f4381a1a44c9a3f8bcca10d369c9fd179939666f Mon Sep 17 00:00:00 2001 From: Tucker Date: Thu, 4 May 2023 17:03:28 -0400 Subject: [PATCH 09/59] Test conditions in id capture manager. --- scenarios/sumo/loop/scenario.py | 5 ++- smarts/core/actor_capture_manager.py | 6 +++- smarts/core/id_actor_capture_manager.py | 45 ++++++++++++++----------- smarts/sstudio/types.py | 17 +++++----- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/scenarios/sumo/loop/scenario.py b/scenarios/sumo/loop/scenario.py index 587bd4fdb9..1dedb4134f 100644 --- a/scenarios/sumo/loop/scenario.py +++ b/scenarios/sumo/loop/scenario.py @@ -1,4 +1,5 @@ import random +import sys from pathlib import Path from smarts.core import seed @@ -54,7 +55,9 @@ [ t.Mission( route=t.RandomRoute(), - entry_tactic=t.IdEntryTactic("other_interest", patience=10), + entry_tactic=t.IdEntryTactic( + "other_interest", t.TimeWindowCondition(0.1, 20.0) + ), ) ], ) diff --git a/smarts/core/actor_capture_manager.py b/smarts/core/actor_capture_manager.py index c7d88cf44b..be68fd429a 100644 --- a/smarts/core/actor_capture_manager.py +++ b/smarts/core/actor_capture_manager.py @@ -21,6 +21,7 @@ # THE SOFTWARE. +import warnings from dataclasses import replace from typing import Optional @@ -101,7 +102,10 @@ def _take_existing_vehicle( assert isinstance(sim, SMARTS) if social: - # Not supported + # MTA: TODO implement this section of actor capture. + warnings.warn( + f"Unable to capture for {agent_id} because social agent id capture not yet implemented." + ) return None vehicle = sim.switch_control_to_agent( vehicle_id, agent_id, mission, recreate=True, is_hijacked=False diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 55f832e1e6..734e30650c 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -25,7 +25,7 @@ from smarts.core.actor_capture_manager import ActorCaptureManager from smarts.core.plan import Mission from smarts.core.vehicle import Vehicle -from smarts.sstudio.types import IdEntryTactic +from smarts.sstudio.types import ConditionState, IdEntryTactic class IdActorCaptureManager(ActorCaptureManager): @@ -60,22 +60,22 @@ def step(self, sim): ): entry_tactic = mission.entry_tactic assert isinstance(entry_tactic, IdEntryTactic) - patience_expiry = mission.start_time + entry_tactic.patience - if sim.elapsed_sim_time > patience_expiry: - self._log.error( + vehicle = sim.vehicle_index.vehicle_by_id(actor_id) + condition_result = entry_tactic.condition.evaluate( + simulation_time=sim.elapsed_sim_time, + actor_ids=sim.vehicle_index.vehicle_ids, + vehicle_state=vehicle.state if vehicle else None, + mission_start_time=mission.start_time, + ) + if condition_result == ConditionState.EXPIRED: + self._log.warning( f"Actor aquisition skipped for `{agent_id}` scheduled to start between " - + f"`{mission.start_time}` and `{patience_expiry}` has expired with no vehicle." - f"`simulation time: {sim.elapsed_sim_time}`" + + f"`Condition `{entry_tactic.condition}` has expired with no vehicle." ) used_actors.append(actor_id) sim.agent_manager.teardown_ego_agents({agent_id}) continue - vehicle = sim.vehicle_index.vehicle_by_id(actor_id) - if not entry_tactic.condition.evaluate( - simulation_time = sim.elapsed_sim_time, - actor_ids = sim.vehicle_index.vehicle_ids, - vehicle_state = vehicle.state if vehicle else None, - ): + if not condition_result: continue vehicle: Optional[Vehicle] = self._take_existing_vehicle( sim, @@ -104,18 +104,25 @@ def reset(self, scenario, sim): for agent_id, mission in missions.items(): if mission is None: continue - if not isinstance(mission.entry_tactic, IdEntryTactic): + entry_tactic = mission.entry_tactic + if not isinstance(entry_tactic, IdEntryTactic): continue - patience_expiry = mission.start_time + mission.entry_tactic.patience - if sim.elapsed_sim_time > patience_expiry: - self._log.error( - f"ID actor capture entry tactic failed for `{agent_id}` scheduled to start between " - + f"`{mission.start_time}` and `{patience_expiry}` because simulation skipped to " + vehicle = sim.vehicle_index.vehicle_by_id(entry_tactic.actor_id, None) + condition_result = entry_tactic.condition.evaluate( + simulation_time=sim.elapsed_sim_time, + actor_ids=sim.vehicle_index.vehicle_ids, + vehicle_state=vehicle.state if vehicle else None, + mission_start_time=mission.start_time, + ) + if condition_result == ConditionState.EXPIRED: + self._log.warning( + f"Actor aquisition skipped for `{agent_id}` scheduled to start with" + + f"`Condition:{entry_tactic.condition}` because simulation skipped to " f"`simulation time: {sim.elapsed_sim_time}`" ) cancelled_agents.add(agent_id) continue - self._actor_for_agent[mission.entry_tactic.actor_id] = (agent_id, mission) + self._actor_for_agent[entry_tactic.actor_id] = (agent_id, mission) if len(cancelled_agents) > 0: sim.agent_manager.teardown_ego_agents(cancelled_agents) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 6ed9326e8f..9c6fd41d2d 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -22,6 +22,7 @@ import logging import math import random +import sys from dataclasses import dataclass, field, replace from enum import IntEnum, IntFlag from sys import maxsize @@ -696,10 +697,10 @@ class TimeWindowCondition(Condition): end: float """The ending simulation time as of which this condition becomes expired.""" - def evaluate(self, *args, simulation_time, **kwargs): + def evaluate(self, *args, simulation_time, mission_start_time, **kwargs): if self.start <= simulation_time < self.end: return ConditionState.TRUE - elif self.end >= simulation_time: + elif simulation_time > self.end: return ConditionState.EXPIRED return ConditionState.BEFORE @@ -833,8 +834,8 @@ class TrapEntryTactic(EntryTactic): """The prefixes of vehicles to avoid hijacking""" default_entry_speed: Optional[float] = None """The speed that the vehicle starts at when the hijack limit expiry emits a new vehicle""" - - condition: Condition = LiteralCondition(ConditionState.FALSE) + condition: Condition = LiteralCondition(ConditionState.TRUE) + """A condition that is used to add additional exclusions.""" @dataclass(frozen=True) @@ -844,14 +845,12 @@ class IdEntryTactic(EntryTactic): actor_id: str """The id of the actor to take over.""" - patience: float = 0.1 - """Defines the amount of time this tactic will wait for an actor.""" - - condition: Condition = LiteralCondition(ConditionState.FALSE) + condition: Condition = TimeWindowCondition(0.1, sys.maxsize) + """A condition that is used to add additional exclusions.""" def __post_init__(self): assert isinstance(self.actor_id, str) - assert isinstance(self.patience, (float, int)) + assert isinstance(self.condition, (Condition)) @dataclass(frozen=True) From 9e9b8d32fbffd913ffcab46b5f9061520a153ec9 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 14:08:03 +0000 Subject: [PATCH 10/59] Deprecated sstudio mission start time. --- .../scenario.py | 2 +- .../scenario.py | 14 +++- .../scenario.py | 73 +++++++++++-------- scenarios/sumo/loop/scenario.py | 4 +- .../sumo/merge/3lane_agents_1/scenario.py | 22 ++++-- .../sumo/merge/3lane_agents_2/scenario.py | 20 ++++- .../merge_exit_sumo_t_agents_1/scenario.py | 5 +- .../3lane_cruise_agents_1/scenario.py | 14 +++- .../3lane_cruise_agents_3/scenario.py | 25 ++++++- .../3lane_cut_in_agents_1/scenario.py | 5 +- .../3lane_overtake_agents_1/scenario.py | 5 +- scenarios/sumo/zoo_intersection/scenario.py | 7 +- smarts/core/id_actor_capture_manager.py | 2 + smarts/core/plan.py | 1 + smarts/core/scenario.py | 17 +++-- smarts/sstudio/types.py | 11 +-- 16 files changed, 156 insertions(+), 71 deletions(-) diff --git a/scenarios/sumo/intersections/1_to_1lane_left_turn_c_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_1lane_left_turn_c_agents_1/scenario.py index 628c16f539..17f3324ec0 100644 --- a/scenarios/sumo/intersections/1_to_1lane_left_turn_c_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_1lane_left_turn_c_agents_1/scenario.py @@ -93,8 +93,8 @@ ego_missions = [ Mission( route=route, - start_time=12, # Delayed start, to ensure road has prior traffic. entry_tactic=TrapEntryTactic( + start_time=12, # Delayed start, to ensure road has prior traffic. wait_to_hijack_limit_s=1, zone=MapZone( start=( diff --git a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py index 54417c77b2..3dfcacb901 100644 --- a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py @@ -23,7 +23,15 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", @@ -86,7 +94,9 @@ ego_missions = [ Mission( route=route, - start_time=4, # Delayed start, to ensure road has prior traffic. + entry_tactic=TrapEntryTactic( + start_time=4, wait_to_hijack_limit_s=0.1 + ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py index fd14a6d308..6a12f07434 100644 --- a/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py @@ -23,65 +23,78 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", ) -horizontal_routes = [ - ("E4", 0, "E1", 0), - ("E4", 1, "E1", 1), - ("-E1", 0, "-E4", 0), - ("-E1", 1, "-E4", 1), -] - -turn_left_routes = [ - ("E0", 0, "E1", 1), - ("E4", 1, "-E0", 0), +# flow_name = (start_lane, end_lane) +route_opt = [ + (0, 0), + (1, 1), + (2, 2), ] -turn_right_routes = [ - ("E0", 0, "-E4", 0), - ("-E1", 0, "-E0", 0), -] - -# Total route combinations = 8C1 + 8C2 + 8C3 + 8C4 + 8C5 = 218 -# Repeated route combinations = 218 * 2 = 436 -all_routes = horizontal_routes + turn_left_routes + turn_right_routes +# Traffic combinations = 3C2 + 3C3 = 3 + 1 = 4 +# Repeated traffic combinations = 4 * 100 = 400 +min_flows = 2 +max_flows = 3 route_comb = [ - com for elems in range(1, 6) for com in combinations(all_routes, elems) -] * 2 + com + for elems in range(min_flows, max_flows + 1) + for com in combinations(route_opt, elems) +] * 100 + traffic = {} for name, routes in enumerate(route_comb): traffic[str(name)] = Traffic( flows=[ Flow( route=Route( - begin=(start_edge, start_lane, 0), - end=(end_edge, end_lane, "max"), + begin=("gneE3", start_lane, 0), + end=("gneE4", end_lane, "max"), ), # Random flow rate, between x and y vehicles per minute. - rate=60 * random.uniform(5, 10), + rate=60 * random.uniform(10, 20), # Random flow start time, between x and y seconds. - begin=random.uniform(0, 3), + begin=random.uniform(0, 5), # For an episode with maximum_episode_steps=3000 and step # time=0.1s, the maximum episode time=300s. Hence, traffic is # set to end at 900s, which is greater than maximum episode # time of 300s. end=60 * 15, actors={normal: 1}, + randomly_spaced=True, ) - for start_edge, start_lane, end_edge, end_lane in routes + for start_lane, end_lane in routes ] ) -route = Route(begin=("E0", 0, 5), end=("E1", 0, "max")) + ego_missions = [ Mission( - route=route, - start_time=4, # Delayed start, to ensure road has prior traffic. - ) + Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), + entry_tactic=TrapEntryTactic( + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0.1), + wait_to_hijack_limit_s=1, + ), + ), + Mission( + Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), + entry_tactic=TrapEntryTactic( + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0.1), + wait_to_hijack_limit_s=1, + ), + ), ] gen_scenario( diff --git a/scenarios/sumo/loop/scenario.py b/scenarios/sumo/loop/scenario.py index 1dedb4134f..90399b2ddc 100644 --- a/scenarios/sumo/loop/scenario.py +++ b/scenarios/sumo/loop/scenario.py @@ -56,7 +56,9 @@ t.Mission( route=t.RandomRoute(), entry_tactic=t.IdEntryTactic( - "other_interest", t.TimeWindowCondition(0.1, 20.0) + start_time=0.1, + actor_id="other_interest", + condition=t.TimeWindowCondition(0.1, 20.0), ), ) ], diff --git a/scenarios/sumo/merge/3lane_agents_1/scenario.py b/scenarios/sumo/merge/3lane_agents_1/scenario.py index 1de44a7fcc..264247876e 100644 --- a/scenarios/sumo/merge/3lane_agents_1/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_1/scenario.py @@ -23,7 +23,15 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", @@ -71,12 +79,16 @@ ] ) -route = Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")) + ego_missions = [ Mission( - route=route, - start_time=15, # Delayed start, to ensure road has prior traffic. - ) + Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + ), + Mission( + Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + ), ] gen_scenario( diff --git a/scenarios/sumo/merge/3lane_agents_2/scenario.py b/scenarios/sumo/merge/3lane_agents_2/scenario.py index 0365a090cb..264247876e 100644 --- a/scenarios/sumo/merge/3lane_agents_2/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_2/scenario.py @@ -23,7 +23,15 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", @@ -73,8 +81,14 @@ ego_missions = [ - Mission(Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), start_time=15), - Mission(Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), start_time=15), + Mission( + Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + ), + Mission( + Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + ), ] gen_scenario( diff --git a/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py b/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py index c533946265..a6637814a0 100644 --- a/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py +++ b/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py @@ -99,8 +99,9 @@ ego_missions = [ EndlessMission( begin=("E0", 2, 5), - start_time=31, - entry_tactic=TrapEntryTactic(wait_to_hijack_limit_s=0, default_entry_speed=0), + entry_tactic=TrapEntryTactic( + start_time=31, wait_to_hijack_limit_s=0, default_entry_speed=0 + ), ) # Delayed start, to ensure road has prior traffic. ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py index 9e62110d74..396cbd230e 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py @@ -23,7 +23,15 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", @@ -76,7 +84,9 @@ ego_missions = [ Mission( route=route, - start_time=17, # Delayed start, to ensure road has prior traffic. + entry_tactic=TrapEntryTactic( + start_time=17, wait_to_hijack_limit_s=0.1 + ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py index 8b922983cd..8a8b5ad2e5 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py @@ -23,7 +23,15 @@ from pathlib import Path from smarts.sstudio import gen_scenario -from smarts.sstudio.types import Flow, Mission, Route, Scenario, Traffic, TrafficActor +from smarts.sstudio.types import ( + Flow, + Mission, + Route, + Scenario, + Traffic, + TrafficActor, + TrapEntryTactic, +) normal = TrafficActor( name="car", @@ -72,9 +80,18 @@ ) ego_missions = [ - Mission(Route(begin=("gneE3", 0, 10), end=("gneE3", 0, "max")), start_time=19), - Mission(Route(begin=("gneE3", 1, 10), end=("gneE3", 1, "max")), start_time=21), - Mission(Route(begin=("gneE3", 2, 10), end=("gneE3", 2, "max")), start_time=17), + Mission( + Route(begin=("gneE3", 0, 10), end=("gneE3", 0, "max")), + entry_tactic=TrapEntryTactic(start_time=19, wait_to_hijack_limit_s=0.1), + ), + Mission( + Route(begin=("gneE3", 1, 10), end=("gneE3", 1, "max")), + entry_tactic=TrapEntryTactic(start_time=21, wait_to_hijack_limit_s=0.1), + ), + Mission( + Route(begin=("gneE3", 2, 10), end=("gneE3", 2, "max")), + entry_tactic=TrapEntryTactic(start_time=17, wait_to_hijack_limit_s=0.1), + ), ] gen_scenario( diff --git a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py index dcc339bd8b..d9d9d8cbec 100644 --- a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py @@ -34,6 +34,7 @@ SmartsLaneChangingModel, Traffic, TrafficActor, + TrapEntryTactic, ) normal = TrafficActor( @@ -94,7 +95,9 @@ ego_missions = [ Mission( route=route, - start_time=20, # Delayed start, to ensure road has prior traffic. + entry_tactic=TrapEntryTactic( + start_time=20, wait_to_hijack_limit_s=0.1 + ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py index 39f14d9863..b20b9a4123 100644 --- a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py @@ -31,6 +31,7 @@ Scenario, Traffic, TrafficActor, + TrapEntryTactic, ) normal = TrafficActor( @@ -83,7 +84,9 @@ ego_missions = [ Mission( route=route, - start_time=17, # Delayed start, to ensure road has prior traffic. + entry_tactic=TrapEntryTactic( + start_time=17, wait_to_hijack_limit_s=0.1 + ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/zoo_intersection/scenario.py b/scenarios/sumo/zoo_intersection/scenario.py index 678970a929..7db35a6d65 100644 --- a/scenarios/sumo/zoo_intersection/scenario.py +++ b/scenarios/sumo/zoo_intersection/scenario.py @@ -96,7 +96,12 @@ f"s-agent-{social_agent1.name}": ( [social_agent1], [ - EndlessMission(begin=("edge-south-SN", 0, 10), start_time=0.7), + EndlessMission( + begin=("edge-south-SN", 0, 10), + entry_tactic=TrapEntryTactic( + start_time=0.7, wait_to_hijack_limit_s=0.1 + ), + ), ], ), }, diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 734e30650c..08f02b079d 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -68,6 +68,7 @@ def step(self, sim): mission_start_time=mission.start_time, ) if condition_result == ConditionState.EXPIRED: + print(condition_result) self._log.warning( f"Actor aquisition skipped for `{agent_id}` scheduled to start between " + f"`Condition `{entry_tactic.condition}` has expired with no vehicle." @@ -77,6 +78,7 @@ def step(self, sim): continue if not condition_result: continue + print(condition_result) vehicle: Optional[Vehicle] = self._take_existing_vehicle( sim, actor_id, diff --git a/smarts/core/plan.py b/smarts/core/plan.py index a0e4aaa165..1f70de8b90 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -166,6 +166,7 @@ def _drove_off_map(self, veh_pos: Point, veh_heading: float) -> bool: def default_entry_tactic(default_entry_speed: Optional[float] = None) -> EntryTactic: """The default tactic the simulation will use to acquire an actor for an agent.""" return TrapEntryTactic( + start_time=0, wait_to_hijack_limit_s=0, exclusion_prefixes=tuple(), zone=None, diff --git a/smarts/core/scenario.py b/smarts/core/scenario.py index f989237582..ef54a0cc77 100644 --- a/smarts/core/scenario.py +++ b/smarts/core/scenario.py @@ -642,13 +642,11 @@ def create_dynamic_traffic_history_mission( positional_mission = Mission( start=start, entry_tactic=entry_tactic, - start_time=0, goal=PositionalGoal(veh_goal, radius=positional_radius), ) traverse_mission = Mission( start=start, entry_tactic=entry_tactic, - start_time=0, goal=TraverseGoal(self._road_map), ) return positional_mission, traverse_mission @@ -795,12 +793,13 @@ def to_scenario_via( ) goal = PositionalGoal(position, radius=2) + entry_tactic = mission.entry_tactic return Mission( start=start, route_vias=mission.route.via, goal=goal, - start_time=mission.start_time, - entry_tactic=mission.entry_tactic, + start_time=entry_tactic.start_time if entry_tactic else 0, + entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), ) elif isinstance(mission, sstudio_types.EndlessMission): @@ -810,11 +809,12 @@ def to_scenario_via( ) start = Start(position, heading) + entry_tactic = mission.entry_tactic return Mission( start=start, goal=EndlessGoal(), - start_time=mission.start_time, - entry_tactic=mission.entry_tactic, + start_time=entry_tactic.start_time if entry_tactic else 0, + entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), ) elif isinstance(mission, sstudio_types.LapMission): @@ -839,12 +839,13 @@ def to_scenario_via( road_map, ) + entry_tactic = mission.entry_tactic return LapMission( start=Start(start_position, start_heading), goal=PositionalGoal(end_position, radius=2), route_vias=mission.route.via, - start_time=mission.start_time, - entry_tactic=mission.entry_tactic, + start_time=entry_tactic.start_time if entry_tactic else 0, + entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), num_laps=mission.num_laps, route_length=route.road_length, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 9c6fd41d2d..38bde74553 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -817,7 +817,7 @@ def evaluate(self, *args, **kwargs): class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" - pass + start_time: float # condition: Condition # """The condition to determine if this entry tactic should be used.""" @@ -863,11 +863,6 @@ class Mission: via: Tuple[Via, ...] = () """Points on an road that an actor must pass through""" - start_time: float = 0.1 - """The earliest simulation time that this mission starts but may start later in couple with - `entry_tactic`. - """ - entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission.""" @@ -888,8 +883,6 @@ class EndlessMission: """ via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: float = 0.1 - """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" @@ -906,8 +899,6 @@ class LapMission: """The amount of times to repeat the mission""" via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: float = 0.1 - """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" From 6c2bddde720db844a4cdc6cc2fc6ed1bddf2ca0e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 16:18:04 +0000 Subject: [PATCH 11/59] Add back start_time but deprecate. --- smarts/core/scenario.py | 21 +++++++++++++++++---- smarts/sstudio/types.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/smarts/core/scenario.py b/smarts/core/scenario.py index ef54a0cc77..d37fdecb4a 100644 --- a/smarts/core/scenario.py +++ b/smarts/core/scenario.py @@ -66,7 +66,7 @@ vec_to_radians, ) from smarts.sstudio import types as sstudio_types -from smarts.sstudio.types import MapSpec +from smarts.sstudio.types import EntryTactic, MapSpec from smarts.sstudio.types import Via as SSVia VehicleWindow = TrafficHistory.TrafficHistoryVehicleWindow @@ -794,11 +794,12 @@ def to_scenario_via( goal = PositionalGoal(position, radius=2) entry_tactic = mission.entry_tactic + start_time = Scenario._extract_mission_start_time(mission, entry_tactic) return Mission( start=start, route_vias=mission.route.via, goal=goal, - start_time=entry_tactic.start_time if entry_tactic else 0, + start_time=start_time, entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), ) @@ -810,10 +811,11 @@ def to_scenario_via( start = Start(position, heading) entry_tactic = mission.entry_tactic + start_time = Scenario._extract_mission_start_time(mission, entry_tactic) return Mission( start=start, goal=EndlessGoal(), - start_time=entry_tactic.start_time if entry_tactic else 0, + start_time=start_time, entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), ) @@ -840,11 +842,12 @@ def to_scenario_via( ) entry_tactic = mission.entry_tactic + start_time = Scenario._extract_mission_start_time(mission, entry_tactic) return LapMission( start=Start(start_position, start_heading), goal=PositionalGoal(end_position, radius=2), route_vias=mission.route.via, - start_time=entry_tactic.start_time if entry_tactic else 0, + start_time=start_time, entry_tactic=entry_tactic, via=to_scenario_via(mission.via, road_map), num_laps=mission.num_laps, @@ -855,6 +858,16 @@ def to_scenario_via( f"sstudio mission={mission} is an invalid type={type(mission)}" ) + @staticmethod + def _extract_mission_start_time(mission, entry_tactic: Optional[EntryTactic]): + return ( + entry_tactic.start_time + if entry_tactic + else mission.start_time + if mission.start_time < sstudio_types.MISSING + else 0 + ) + @staticmethod def is_valid_scenario(scenario_root) -> bool: """Checks if the scenario_root directory matches our expected scenario structure diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 38bde74553..d2b56849e4 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -863,9 +863,21 @@ class Mission: via: Tuple[Via, ...] = () """Points on an road that an actor must pass through""" + start_time: float = MISSING + """The earliest simulation time that this mission starts but may start later in couple with + `entry_tactic`. + """ + entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission.""" + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + @dataclass(frozen=True) class EndlessMission: @@ -883,9 +895,18 @@ class EndlessMission: """ via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" + start_time: float = MISSING + """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + @dataclass(frozen=True) class LapMission: @@ -899,9 +920,18 @@ class LapMission: """The amount of times to repeat the mission""" via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" + start_time: float = MISSING + """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + @dataclass(frozen=True) class GroupedLapMission: @@ -1139,9 +1169,9 @@ def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: class BubbleLimits: """Defines the capture limits of a bubble.""" - hijack_limit: int = maxsize + hijack_limit: int = sys.maxsize """The maximum number of vehicles the bubble can hijack""" - shadow_limit: int = maxsize + shadow_limit: int = sys.maxsize """The maximum number of vehicles the bubble can shadow""" def __post_init__(self): From d4d16775b92078222f5205b40b3eed8f7a0fe9c4 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 16:34:22 +0000 Subject: [PATCH 12/59] Fix missing constant. --- smarts/sstudio/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index d2b56849e4..e464881541 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -23,9 +23,9 @@ import math import random import sys +import warnings from dataclasses import dataclass, field, replace from enum import IntEnum, IntFlag -from sys import maxsize from typing import ( Any, Callable, @@ -60,6 +60,8 @@ from smarts.core.utils.id import SocialAgentId from smarts.core.utils.math import rotate_cw_around_point +MISSING = sys.maxsize + class _SUMO_PARAMS_MODE(IntEnum): TITLE_CASE = 0 From 2ba1928573830ee6845b151f7e0611277d74dcbe Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 16:35:15 +0000 Subject: [PATCH 13/59] Change NameError to DeprecationWarning for deprecation. --- smarts/core/agent_interface.py | 3 ++- smarts/core/events.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index 23fa799fd1..1ae2e95435 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -254,7 +254,8 @@ class DoneCriteria: @property def actors_alive(self): """Deprecated. Use interest.""" - raise NameError("Deprecated. Use interest.") + warnings.warn("Use interest.", category=DeprecationWarning) + return self.interest @dataclass diff --git a/smarts/core/events.py b/smarts/core/events.py index 61fb939af5..405f6d636a 100644 --- a/smarts/core/events.py +++ b/smarts/core/events.py @@ -17,6 +17,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import warnings from typing import NamedTuple, Sequence @@ -48,4 +49,5 @@ class Events(NamedTuple): @property def actors_alive_done(self): """Deprecated. Use interest_done.""" - raise NameError("Deprecated. Use interest_done.") + warnings.warn("Use interest_done.", category=DeprecationWarning) + return self.interest_done From 006a4cc3096f52a83f7be2e12ebedf6330c5f402 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 16:36:20 +0000 Subject: [PATCH 14/59] Give GroupedLapMission entry tactic. --- smarts/sstudio/genscenario.py | 3 +++ smarts/sstudio/types.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/smarts/sstudio/genscenario.py b/smarts/sstudio/genscenario.py index 3eb22fa470..b1840c6bad 100644 --- a/smarts/sstudio/genscenario.py +++ b/smarts/sstudio/genscenario.py @@ -256,6 +256,7 @@ def gen_scenario( grid_offset=mission.offset, used_lanes=mission.lanes, vehicle_count=mission.actor_count, + entry_tactic=mission.entry_tactic, num_laps=mission.num_laps, map_spec=map_spec, ) @@ -475,6 +476,7 @@ def gen_group_laps( grid_offset: int, used_lanes: int, vehicle_count: int, + entry_tactic: Optional[types.EntryTactic], num_laps: int = 3, map_spec: Optional[types.MapSpec] = None, ): @@ -516,6 +518,7 @@ def gen_group_laps( end=(end_road_id, (end_lane + i) % used_lanes, end_offset), ), num_laps=num_laps, + entry_tactic=entry_tactic, # route_length=route_length, ) ) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index e464881541..6ac0fba0bf 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -951,6 +951,8 @@ class GroupedLapMission: """The amount of times to repeat the mission""" via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" + entry_tactic: Optional[EntryTactic] = None + """A specific tactic the mission should employ to start the mission""" @dataclass(frozen=True) From afd5d492d422e24d2fc9e90b062d411a848638c3 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:07:00 +0000 Subject: [PATCH 15/59] Fix delay condition. --- smarts/sstudio/types.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 6ac0fba0bf..48a3a5ddb7 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -737,29 +737,33 @@ class DelayCondition(Condition): This can be used to wait for some time after the inner condition has become true to be true. Note that the original condition may no longer be true by the time delay has expired. + + This will never resolve TRUE on the first evaluate. """ inner_condition: Condition """The inner condition to delay.""" - delay_seconds: float + seconds: float """The number of seconds to delay for.""" - inner_affects_final_result: bool = False + persistant: bool = False """If the inner condition must still be true at the end of the delay to be true.""" def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: key = "met_time" if (met_time := getattr(self, key, None)) is not None: - if simulation_time > met_time + self.delay_seconds: + if simulation_time >= met_time + self.seconds: result = ConditionState.TRUE - if self.inner_affects_final_result: + if self.persistant: result &= self.inner_condition.evaluate( - *args, simulation_time, **kwargs + *args, simulation_time=simulation_time, **kwargs ) return result - elif self.inner_condition.evaluate(*args, simulation_time, **kwargs): - setattr(self, key, simulation_time) + elif self.inner_condition.evaluate( + *args, simulation_time=simulation_time, **kwargs + ): + object.__setattr__(self, key, simulation_time) return ConditionState.FALSE From 58d23e67130958d711ad0279b6494bdf11aa2e4e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:07:59 +0000 Subject: [PATCH 16/59] Disallow abstract condition operations. --- smarts/sstudio/types.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 48a3a5ddb7..b4188a3f6a 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -646,20 +646,48 @@ def evaluate(self, *args, **kwargs) -> ConditionState: def negate(self) -> "NegatedCondition": """Negates this condition.""" + abstract_conditions = (Condition, SubjectCondition) + if self.__class__ in abstract_conditions: + raise TypeError("Base condition cannot be negated.") return NegatedCondition(self) def conjoin(self, other: "Condition") -> "CompoundCondition": """AND's this condition with the other condition.""" + abstract_conditions = (Condition, SubjectCondition) + if ( + self.__class__ in abstract_conditions + or other.__class__ in abstract_conditions + ): + raise TypeError("Base condition cannot be conjoined.") return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) def disjoin(self, other: "Condition") -> "CompoundCondition": """OR's this condition with the other condition.""" + abstract_conditions = (Condition, SubjectCondition) + if ( + self.__class__ in abstract_conditions + or other.__class__ in abstract_conditions + ): + raise TypeError("Base condition cannot be disjoined.") return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) def implicate(self, other: "Condition") -> "CompoundCondition": """Current condition must be false or both conditions true to be true.""" + abstract_conditions = (Condition, SubjectCondition) + if ( + self.__class__ in abstract_conditions + or other.__class__ in abstract_conditions + ): + raise TypeError("Base condition cannot be implicated.") return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) + def delay(self, seconds, persistant=False) -> "DelayCondition": + """Delays the current condition until the given number of simulation seconds have occured.""" + abstract_conditions = (Condition, SubjectCondition) + if self.__class__ in abstract_conditions: + raise TypeError("Base condition cannot be delayed.") + return DelayCondition(self, seconds=seconds, persistant=persistant) + @dataclass(frozen=True) class SubjectCondition(Condition): From bac5cf938841ac21c465c9eec30b0cefcd0c8f66 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:15:37 +0000 Subject: [PATCH 17/59] Move abstract condition constraints down. --- smarts/sstudio/types.py | 50 ++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index b4188a3f6a..97b88e6e38 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -646,46 +646,22 @@ def evaluate(self, *args, **kwargs) -> ConditionState: def negate(self) -> "NegatedCondition": """Negates this condition.""" - abstract_conditions = (Condition, SubjectCondition) - if self.__class__ in abstract_conditions: - raise TypeError("Base condition cannot be negated.") return NegatedCondition(self) def conjoin(self, other: "Condition") -> "CompoundCondition": """AND's this condition with the other condition.""" - abstract_conditions = (Condition, SubjectCondition) - if ( - self.__class__ in abstract_conditions - or other.__class__ in abstract_conditions - ): - raise TypeError("Base condition cannot be conjoined.") return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) def disjoin(self, other: "Condition") -> "CompoundCondition": """OR's this condition with the other condition.""" - abstract_conditions = (Condition, SubjectCondition) - if ( - self.__class__ in abstract_conditions - or other.__class__ in abstract_conditions - ): - raise TypeError("Base condition cannot be disjoined.") return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) def implicate(self, other: "Condition") -> "CompoundCondition": """Current condition must be false or both conditions true to be true.""" - abstract_conditions = (Condition, SubjectCondition) - if ( - self.__class__ in abstract_conditions - or other.__class__ in abstract_conditions - ): - raise TypeError("Base condition cannot be implicated.") return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) def delay(self, seconds, persistant=False) -> "DelayCondition": """Delays the current condition until the given number of simulation seconds have occured.""" - abstract_conditions = (Condition, SubjectCondition) - if self.__class__ in abstract_conditions: - raise TypeError("Base condition cannot be delayed.") return DelayCondition(self, seconds=seconds, persistant=persistant) @@ -758,6 +734,13 @@ class NegatedCondition(Condition): def evaluate(self, *args, **kwargs) -> ConditionState: return ~self.inner_condition.evaluate(*args, **kwargs) + def __post_init__(self): + abstract_conditions = (Condition, SubjectCondition) + if self.inner_condition.__class__ in abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__name__}` cannot use the negation operation." + ) + @dataclass(frozen=True) class DelayCondition(Condition): @@ -794,6 +777,13 @@ def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: object.__setattr__(self, key, simulation_time) return ConditionState.FALSE + def __post_init__(self): + abstract_conditions = (Condition, SubjectCondition) + if self.inner_condition.__class__ in abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__name__}` cannot use delay operations." + ) + @dataclass(frozen=True) class OnRoadCondition(SubjectCondition): @@ -842,10 +832,20 @@ def evaluate(self, *args, **kwargs): if self.operator == ConditionOperator.CONJUNCTION: return eval_0 & eval_1 elif self.operator == ConditionOperator.DISJUNCTION: - return eval_0 | eval_1 + return (eval_0 | eval_1) & ConditionState.TRUE return ConditionState.FALSE + def __post_init__(self): + abstract_conditions = (Condition, SubjectCondition) + if ( + self.first_condition.__class__ in abstract_conditions + or self.second_condition.__class__ in abstract_conditions + ): + raise TypeError( + f"Abstract `{self.inner_condition.__name__}` cannot use compound operations." + ) + @dataclass(frozen=True) class EntryTactic: From 233dd444e96569a6c4d9511d88a533e83035c1a4 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:17:41 +0000 Subject: [PATCH 18/59] Centralize listing of abstract conditions. --- smarts/sstudio/types.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 97b88e6e38..acad7c3cca 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -679,6 +679,7 @@ def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: """ raise NotImplementedError() +_abstract_conditions = (Condition, SubjectCondition) @dataclass(frozen=True) class LiteralCondition(Condition): @@ -735,8 +736,7 @@ def evaluate(self, *args, **kwargs) -> ConditionState: return ~self.inner_condition.evaluate(*args, **kwargs) def __post_init__(self): - abstract_conditions = (Condition, SubjectCondition) - if self.inner_condition.__class__ in abstract_conditions: + if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( f"Abstract `{self.inner_condition.__name__}` cannot use the negation operation." ) @@ -778,8 +778,7 @@ def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: return ConditionState.FALSE def __post_init__(self): - abstract_conditions = (Condition, SubjectCondition) - if self.inner_condition.__class__ in abstract_conditions: + if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( f"Abstract `{self.inner_condition.__name__}` cannot use delay operations." ) @@ -837,10 +836,9 @@ def evaluate(self, *args, **kwargs): return ConditionState.FALSE def __post_init__(self): - abstract_conditions = (Condition, SubjectCondition) if ( - self.first_condition.__class__ in abstract_conditions - or self.second_condition.__class__ in abstract_conditions + self.first_condition.__class__ in _abstract_conditions + or self.second_condition.__class__ in _abstract_conditions ): raise TypeError( f"Abstract `{self.inner_condition.__name__}` cannot use compound operations." From 5ddf0e967738aeb02fce8c1cb9e3edc3b4eb178e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:28:24 +0000 Subject: [PATCH 19/59] Fix naming in error. --- smarts/sstudio/types.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index acad7c3cca..fd14dc93c4 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -738,7 +738,7 @@ def evaluate(self, *args, **kwargs) -> ConditionState: def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( - f"Abstract `{self.inner_condition.__name__}` cannot use the negation operation." + f"Abstract `{self.inner_condition.__class__.__name__}` cannot use the negation operation." ) @@ -780,7 +780,7 @@ def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( - f"Abstract `{self.inner_condition.__name__}` cannot use delay operations." + f"Abstract `{self.inner_condition.__class__.__name__}` cannot use delay operations." ) @@ -836,13 +836,11 @@ def evaluate(self, *args, **kwargs): return ConditionState.FALSE def __post_init__(self): - if ( - self.first_condition.__class__ in _abstract_conditions - or self.second_condition.__class__ in _abstract_conditions - ): - raise TypeError( - f"Abstract `{self.inner_condition.__name__}` cannot use compound operations." - ) + for condition in (self.first_condition, self.second_condition): + if condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{condition.__class__.__name__}` cannot use compound operations." + ) @dataclass(frozen=True) From 4b92b798e11ffd40c553edf61c779146408a19cf Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:28:56 +0000 Subject: [PATCH 20/59] Remove simplified negation. --- smarts/sstudio/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index fd14dc93c4..8691aade19 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -691,9 +691,6 @@ class LiteralCondition(Condition): def evaluate(self, *args, **kwargs) -> ConditionState: return self.literal - def negate(self) -> "LiteralCondition": - return LiteralCondition(~self.literal) - @dataclass(frozen=True) class TimeWindowCondition(Condition): From 3a9004e7a33eeefff3bd512c7298b5f8a8d5a92c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:38:55 +0000 Subject: [PATCH 21/59] Make conditions solely true. --- smarts/sstudio/types.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 8691aade19..5fc31dcfe0 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -826,9 +826,15 @@ def evaluate(self, *args, **kwargs): return ConditionState.TRUE if self.operator == ConditionOperator.CONJUNCTION: - return eval_0 & eval_1 + result = eval_0 & eval_1 + if result: + return ConditionState.TRUE + return result elif self.operator == ConditionOperator.DISJUNCTION: - return (eval_0 | eval_1) & ConditionState.TRUE + result = eval_0 | eval_1 + if result: + return ConditionState.TRUE + return result return ConditionState.FALSE From 0ee661cd3945d6209a1b71d558ae6d72d3d1c064 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:39:14 +0000 Subject: [PATCH 22/59] Remove unused arguement. --- smarts/sstudio/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 5fc31dcfe0..d60399b62c 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -701,7 +701,7 @@ class TimeWindowCondition(Condition): end: float """The ending simulation time as of which this condition becomes expired.""" - def evaluate(self, *args, simulation_time, mission_start_time, **kwargs): + def evaluate(self, *args, simulation_time, **kwargs): if self.start <= simulation_time < self.end: return ConditionState.TRUE elif simulation_time > self.end: From 01efa0ef913ceb79394f70607a9bfe6a3cd70d5d Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:39:25 +0000 Subject: [PATCH 23/59] Add tests. --- smarts/sstudio/tests/test_conditions.py | 253 ++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 smarts/sstudio/tests/test_conditions.py diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py new file mode 100644 index 0000000000..ef581c01ff --- /dev/null +++ b/smarts/sstudio/tests/test_conditions.py @@ -0,0 +1,253 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import pytest + +from smarts.sstudio.types import ( + CompoundCondition, + Condition, + ConditionOperator, + ConditionState, + DelayCondition, + DependeeActorCondition, + LiteralCondition, + NegatedCondition, + OnRoadCondition, + SubjectCondition, + TimeWindowCondition, + VehicleTypeCondition, +) + + +def test_condition_state(): + assert not bool(ConditionState.BEFORE) + assert not bool(ConditionState.EXPIRED) + assert not bool(ConditionState.FALSE) + assert bool(ConditionState.TRUE) + + assert ConditionState.TRUE + assert bool(not ConditionState.TRUE) == False + assert not ~ConditionState.TRUE + + assert bool(ConditionState.FALSE) == False + assert not ConditionState.FALSE + assert ~ConditionState.FALSE + + assert bool(ConditionState.EXPIRED) == False + assert not ConditionState.EXPIRED + assert ~ConditionState.EXPIRED + + assert bool(ConditionState.BEFORE) == False + assert not ConditionState.BEFORE + assert ~ConditionState.BEFORE + + assert ConditionState.TRUE | ConditionState.FALSE + assert not ConditionState.TRUE & ConditionState.FALSE + assert ( + ConditionState.TRUE + | ConditionState.EXPIRED + | ConditionState.FALSE + | ConditionState.BEFORE + ) + assert not (ConditionState.EXPIRED | ConditionState.FALSE | ConditionState.BEFORE) + + +def test_condition(): + literal_true = LiteralCondition(ConditionState.TRUE) + + condition = Condition() + + with pytest.raises(NotImplementedError): + condition.evaluate(actor_info=None) + + with pytest.raises(TypeError): + condition.negate() + + with pytest.raises(TypeError): + condition.conjoin(literal_true) + + with pytest.raises(TypeError): + condition.disjoin(literal_true) + + with pytest.raises(TypeError): + condition.implicate(literal_true) + + with pytest.raises(TypeError): + condition.delay(10) + + +def test_compound_condition(): + literal_true = LiteralCondition(ConditionState.TRUE) + literal_false = LiteralCondition(ConditionState.FALSE) + + assert CompoundCondition( + first_condition=literal_true, + second_condition=literal_false, + operator=ConditionOperator.CONJUNCTION, + ) == literal_true.conjoin(literal_false) + assert literal_true.conjoin(literal_true).evaluate() + assert not literal_true.conjoin(literal_false).evaluate() + assert not literal_false.conjoin(literal_true).evaluate() + assert not literal_false.conjoin(literal_false).evaluate() + + assert CompoundCondition( + first_condition=literal_true, + second_condition=literal_false, + operator=ConditionOperator.DISJUNCTION, + ) == literal_true.disjoin(literal_false) + assert literal_true.disjoin(literal_true) + assert literal_true.disjoin(literal_false).evaluate() + assert literal_false.disjoin(literal_true).evaluate() + assert not literal_false.disjoin(literal_false).evaluate() + + assert CompoundCondition( + first_condition=literal_true, + second_condition=literal_false, + operator=ConditionOperator.IMPLICATION, + ) == literal_true.implicate(literal_false) + assert literal_true.implicate(literal_true).evaluate() + assert not literal_true.implicate(literal_false).evaluate() + assert literal_false.implicate(literal_true).evaluate() + assert literal_false.implicate(literal_false).evaluate() + + +def test_delay_condition(): + short_delay = 4 + long_delay = 10 + first_time_window_true = 5 + window_condition = TimeWindowCondition(4, 10) + delayed_condition = window_condition.delay(long_delay, persistant=False) + + assert delayed_condition == DelayCondition( + inner_condition=window_condition, + seconds=long_delay, + persistant=False, + ) + + # before + time = 2 + assert ( + not delayed_condition.evaluate(simulation_time=time) + ) and not window_condition.evaluate(simulation_time=time) + # first true + time = first_time_window_true + assert ( + not delayed_condition.evaluate(simulation_time=time) + ) and window_condition.evaluate(simulation_time=time) + # delay not expired + time = first_time_window_true + long_delay - 1 + assert ( + not delayed_condition.evaluate(simulation_time=time) + ) and not window_condition.evaluate(simulation_time=time) + # delay expired + time = first_time_window_true + long_delay + assert delayed_condition.evaluate( + simulation_time=time + ) and not window_condition.evaluate(simulation_time=time) + # delay expired + time = first_time_window_true + long_delay + 1 + assert delayed_condition.evaluate( + simulation_time=time + ) and not window_condition.evaluate(simulation_time=time) + # delay not expired + time = first_time_window_true + long_delay - 1 + assert not delayed_condition.evaluate(simulation_time=time) + + # Test persistant true + delayed_condition = window_condition.delay(short_delay, persistant=True) + time = first_time_window_true + assert not delayed_condition.evaluate(simulation_time=time) + time = first_time_window_true + short_delay + assert delayed_condition.evaluate(simulation_time=time) + time = first_time_window_true + long_delay + assert not delayed_condition.evaluate(simulation_time=time) + + +def test_dependee_condition(): + dependee_condition = DependeeActorCondition("leader") + pass + + +def test_literal_condition(): + literal_true = LiteralCondition(ConditionState.TRUE) + literal_false = LiteralCondition(ConditionState.FALSE) + + assert literal_false.evaluate() == ConditionState.FALSE + assert literal_true.evaluate() == ConditionState.TRUE + assert literal_true.evaluate() + assert not literal_false.evaluate() + + +def test_negated_condition(): + literal_true = LiteralCondition(ConditionState.TRUE) + literal_false = LiteralCondition(ConditionState.FALSE) + + assert literal_false.negate() == NegatedCondition(literal_false) + assert literal_true.negate() == NegatedCondition(literal_true) + + assert literal_false.negate().evaluate() + assert not literal_true.negate().evaluate() + + +def test_on_road_condition(): + on_road_condition = OnRoadCondition() + pass + + +def test_time_window_condition(): + start = 4 + between = 8 + end = 10 + + window_condition = TimeWindowCondition(start=start, end=end) + + assert not window_condition.evaluate(simulation_time=start - 1) + assert window_condition.evaluate(simulation_time=start) + assert window_condition.evaluate(simulation_time=between) + assert not window_condition.evaluate(simulation_time=end) + + +def test_subject_condition(): + literal_true = LiteralCondition(ConditionState.TRUE) + + subject_condition = SubjectCondition() + + with pytest.raises(NotImplementedError): + subject_condition.evaluate(actor_info=None) + + with pytest.raises(TypeError): + subject_condition.negate() + + with pytest.raises(TypeError): + subject_condition.conjoin(literal_true) + + with pytest.raises(TypeError): + subject_condition.disjoin(literal_true) + + with pytest.raises(TypeError): + subject_condition.implicate(literal_true) + + with pytest.raises(TypeError): + subject_condition.delay(10) + + +def test_vehicle_type_condition(): + pass From 4ad34113d5bfd3c6c097ae31b92d1287d0ae6b5d Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 19:48:03 +0000 Subject: [PATCH 24/59] Add vehicle type condition tests. --- smarts/sstudio/tests/test_conditions.py | 13 ++++++++++++- smarts/sstudio/types.py | 2 ++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index ef581c01ff..01e61bec04 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -19,6 +19,8 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +from unittest.mock import MagicMock, Mock + import pytest from smarts.sstudio.types import ( @@ -250,4 +252,13 @@ def test_subject_condition(): def test_vehicle_type_condition(): - pass + vehicle_type_condition = VehicleTypeCondition("passenger") + + passenger_vehicle_state = Mock() + passenger_vehicle_state.vehicle_config_type = "passenger" + + truck_vehicle_state = Mock() + truck_vehicle_state.vehicle_config_type = "truck" + + assert vehicle_type_condition.evaluate(vehicle_state=passenger_vehicle_state) + assert not vehicle_type_condition.evaluate(vehicle_state=truck_vehicle_state) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index d60399b62c..d49642bdd8 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -679,8 +679,10 @@ def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: """ raise NotImplementedError() + _abstract_conditions = (Condition, SubjectCondition) + @dataclass(frozen=True) class LiteralCondition(Condition): """This condition evaluates as a literal without considering evaluation parameters.""" From 675db3df8ffb80798d860c5d7b4bc8b8df3998e5 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 20:10:33 +0000 Subject: [PATCH 25/59] Add vehicle speed condition. --- smarts/sstudio/tests/test_conditions.py | 23 ++++++++++++++++++++++- smarts/sstudio/types.py | 25 ++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 01e61bec04..a147f48675 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -19,7 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock import pytest @@ -35,6 +35,7 @@ OnRoadCondition, SubjectCondition, TimeWindowCondition, + VehicleSpeedCondition, VehicleTypeCondition, ) @@ -251,6 +252,26 @@ def test_subject_condition(): subject_condition.delay(10) +def test_vehicle_speed_condition(): + low = 30 + between = 50 + high = 100 + vehicle_speed_condition = VehicleSpeedCondition(low, high) + + slow_vehicle_state = Mock() + slow_vehicle_state.speed = low - 10 + + between_vehicle_state = Mock() + between_vehicle_state.speed = between + + fast_vehicle_state = Mock() + fast_vehicle_state.speed = high + 50 + + assert not vehicle_speed_condition.evaluate(vehicle_state=slow_vehicle_state) + assert vehicle_speed_condition.evaluate(vehicle_state=between_vehicle_state) + assert not vehicle_speed_condition.evaluate(vehicle_state=fast_vehicle_state) + + def test_vehicle_type_condition(): vehicle_type_condition = VehicleTypeCondition("passenger") diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index d49642bdd8..94bffe712e 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -35,6 +35,7 @@ Optional, Sequence, Tuple, + Type, Union, ) @@ -793,7 +794,7 @@ def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: @dataclass(frozen=True) class VehicleTypeCondition(SubjectCondition): - """This condition is true if the subject is of the given types.""" + """This condition is true if the subject is of the given vehicle types.""" vehicle_type: str @@ -805,6 +806,28 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: ) +@dataclass(frozen=True) +class VehicleSpeedCondition(SubjectCondition): + """This condition is true if the subject has a speed between low and high.""" + + low: float + """The lowest speed allowed.""" + + high: float + """The highest speed allowed.""" + + def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: + return ( + ConditionState.TRUE + if self.low <= vehicle_state.speed <= self.high + else ConditionState.FALSE + ) + + @classmethod + def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): + return cls(low=abs_error, high=abs_error) + + @dataclass(frozen=True) class CompoundCondition(Condition): """This condition should be true if the given actor exists.""" From 4256a6758674f698a25b49465c363e4291c1b61b Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 20:14:38 +0000 Subject: [PATCH 26/59] Add time window check for after. --- smarts/sstudio/tests/test_conditions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index a147f48675..269559eea9 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -226,6 +226,7 @@ def test_time_window_condition(): assert window_condition.evaluate(simulation_time=start) assert window_condition.evaluate(simulation_time=between) assert not window_condition.evaluate(simulation_time=end) + assert not window_condition.evaluate(simulation_time=end + 1) def test_subject_condition(): From 22e196f38418d9f289af4cab0bcc575d4705d64a Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 20:16:30 +0000 Subject: [PATCH 27/59] Update subject condition. --- smarts/sstudio/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 94bffe712e..b7abd3e02a 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -670,7 +670,7 @@ def delay(self, seconds, persistant=False) -> "DelayCondition": class SubjectCondition(Condition): """This condition assumes that there is a subject involved.""" - def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: + def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: """Used to evaluate if a condition is met. Args: @@ -788,8 +788,8 @@ def __post_init__(self): class OnRoadCondition(SubjectCondition): """This condition is true if the subject is on road.""" - def evaluate(self, *args, actor_info, **kwargs) -> ConditionState: - return ConditionState.TRUE if actor_info.on_road else ConditionState.FALSE + def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: + return ConditionState.TRUE if vehicle_state.on_road else ConditionState.FALSE @dataclass(frozen=True) From 9222e711fe9d4fa3b8b1e332a9388c2626de56b6 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 22:24:12 +0000 Subject: [PATCH 28/59] Fix missing import --- scenarios/sumo/zoo_intersection/scenario.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scenarios/sumo/zoo_intersection/scenario.py b/scenarios/sumo/zoo_intersection/scenario.py index 7db35a6d65..0dcbcaff2c 100644 --- a/scenarios/sumo/zoo_intersection/scenario.py +++ b/scenarios/sumo/zoo_intersection/scenario.py @@ -13,6 +13,7 @@ SocialAgentActor, Traffic, TrafficActor, + TrapEntryTactic, ) # Traffic Vehicles From cc05370354465a33a4a5b3dc605fbe4744fb51da Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 22:25:56 +0000 Subject: [PATCH 29/59] Finish ConditionState cases. --- smarts/sstudio/tests/test_conditions.py | 70 +++++++++++++++++++------ smarts/sstudio/types.py | 49 ++++++++++++++--- 2 files changed, 94 insertions(+), 25 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 269559eea9..5d7e2ccbdc 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -41,16 +41,16 @@ def test_condition_state(): - assert not bool(ConditionState.BEFORE) + assert bool(ConditionState.TRUE) assert not bool(ConditionState.EXPIRED) + assert not bool(ConditionState.BEFORE) assert not bool(ConditionState.FALSE) - assert bool(ConditionState.TRUE) assert ConditionState.TRUE - assert bool(not ConditionState.TRUE) == False + assert (not ConditionState.TRUE) == False assert not ~ConditionState.TRUE - assert bool(ConditionState.FALSE) == False + assert ConditionState.FALSE == False assert not ConditionState.FALSE assert ~ConditionState.FALSE @@ -100,36 +100,72 @@ def test_condition(): def test_compound_condition(): literal_true = LiteralCondition(ConditionState.TRUE) literal_false = LiteralCondition(ConditionState.FALSE) + literal_before = LiteralCondition(ConditionState.BEFORE) + literal_expired = LiteralCondition(ConditionState.EXPIRED) assert CompoundCondition( first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.CONJUNCTION, ) == literal_true.conjoin(literal_false) - assert literal_true.conjoin(literal_true).evaluate() - assert not literal_true.conjoin(literal_false).evaluate() - assert not literal_false.conjoin(literal_true).evaluate() - assert not literal_false.conjoin(literal_false).evaluate() + assert literal_true.conjoin(literal_true).evaluate() == ConditionState.TRUE + + assert literal_expired.conjoin(literal_expired).evaluate() == ConditionState.EXPIRED + assert literal_expired.conjoin(literal_true).evaluate() == ConditionState.EXPIRED + assert literal_expired.conjoin(literal_before).evaluate() == ConditionState.EXPIRED + assert literal_expired.conjoin(literal_false).evaluate() == ConditionState.EXPIRED + + assert literal_before.conjoin(literal_true).evaluate() == ConditionState.BEFORE + assert literal_before.conjoin(literal_before).evaluate() == ConditionState.BEFORE + assert literal_before.conjoin(literal_false).evaluate() == ConditionState.BEFORE + + assert literal_false.conjoin(literal_true).evaluate() == ConditionState.FALSE + assert literal_false.conjoin(literal_false).evaluate() == ConditionState.FALSE assert CompoundCondition( first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.DISJUNCTION, ) == literal_true.disjoin(literal_false) - assert literal_true.disjoin(literal_true) - assert literal_true.disjoin(literal_false).evaluate() - assert literal_false.disjoin(literal_true).evaluate() - assert not literal_false.disjoin(literal_false).evaluate() + assert literal_true.disjoin(literal_true).evaluate() == ConditionState.TRUE + assert literal_true.disjoin(literal_false).evaluate() == ConditionState.TRUE + assert literal_true.disjoin(literal_before).evaluate() == ConditionState.TRUE + assert literal_true.disjoin(literal_expired).evaluate() == ConditionState.TRUE + + assert literal_before.disjoin(literal_before).evaluate() == ConditionState.BEFORE + assert literal_before.disjoin(literal_expired).evaluate() == ConditionState.BEFORE + assert literal_before.disjoin(literal_false).evaluate() == ConditionState.BEFORE + + assert literal_expired.disjoin(literal_expired).evaluate() == ConditionState.EXPIRED + assert literal_expired.disjoin(literal_false).evaluate() == ConditionState.EXPIRED + + assert literal_false.disjoin(literal_false).evaluate() == ConditionState.FALSE assert CompoundCondition( first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.IMPLICATION, ) == literal_true.implicate(literal_false) - assert literal_true.implicate(literal_true).evaluate() - assert not literal_true.implicate(literal_false).evaluate() - assert literal_false.implicate(literal_true).evaluate() - assert literal_false.implicate(literal_false).evaluate() + assert literal_true.implicate(literal_true).evaluate() == ConditionState.TRUE + + assert literal_true.implicate(literal_expired).evaluate() == ConditionState.FALSE + assert literal_true.implicate(literal_before).evaluate() == ConditionState.FALSE + assert literal_true.implicate(literal_false).evaluate() == ConditionState.FALSE + + assert literal_expired.implicate(literal_true).evaluate() == ConditionState.TRUE + assert literal_expired.implicate(literal_expired).evaluate() == ConditionState.TRUE + assert literal_expired.implicate(literal_before).evaluate() == ConditionState.TRUE + assert literal_expired.implicate(literal_false).evaluate() == ConditionState.TRUE + + assert literal_before.implicate(literal_true).evaluate() == ConditionState.TRUE + assert literal_before.implicate(literal_expired).evaluate() == ConditionState.TRUE + assert literal_before.implicate(literal_before).evaluate() == ConditionState.TRUE + assert literal_before.implicate(literal_false).evaluate() == ConditionState.TRUE + + assert literal_false.implicate(literal_true).evaluate() == ConditionState.TRUE + assert literal_false.implicate(literal_expired).evaluate() == ConditionState.TRUE + assert literal_false.implicate(literal_false).evaluate() == ConditionState.TRUE + assert literal_false.implicate(literal_false).evaluate() == ConditionState.TRUE def test_delay_condition(): @@ -235,7 +271,7 @@ def test_subject_condition(): subject_condition = SubjectCondition() with pytest.raises(NotImplementedError): - subject_condition.evaluate(actor_info=None) + subject_condition.evaluate(vehicle_state=None) with pytest.raises(TypeError): subject_condition.negate() diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index b7abd3e02a..5fb107cf8f 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -830,7 +830,24 @@ def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): @dataclass(frozen=True) class CompoundCondition(Condition): - """This condition should be true if the given actor exists.""" + """This compounds multiple conditions. + + The following cases are notable + CONJUNCTION + If both conditions evaluate TRUE the result is exclusively TRUE. + Else if either condition evaluates EXPIRED the result will be EXPIRED. + Else if either condition evaluates BEFORE the result will be BEFORE. + Else FALSE + DISJUNCTION + If either condition evaluates TRUE the result is exclusively TRUE. + Else if either condition evaluates BEFORE then the result will be BEFORE. + Else if either condition evaluates EXPIRED then the result will be EXPIRED. + Else FALSE + IMPLICATION + If the first condition evaluates *not* TRUE the result is exclusively TRUE. + Else if the first condition evaluates TRUE and the second condition evaluates TRUE the result is exclusively TRUE. + Else FALSE + """ first_condition: Condition """The first condition.""" @@ -842,23 +859,39 @@ class CompoundCondition(Condition): """The operator used to combine these conditions.""" def evaluate(self, *args, **kwargs): - eval_0 = self.first_condition.evaluate(*args, **kwargs) - if self.operator == ConditionOperator.IMPLICATION and not eval_0: + first_eval = self.first_condition.evaluate(*args, **kwargs) + if self.operator == ConditionOperator.IMPLICATION and not first_eval: return ConditionState.TRUE - eval_1 = self.second_condition.evaluate(*args, **kwargs) - if self.operator == ConditionOperator.IMPLICATION and eval_0 and eval_1: + second_eval = self.second_condition.evaluate(*args, **kwargs) + if ( + self.operator == ConditionOperator.IMPLICATION + and first_eval + and second_eval + ): return ConditionState.TRUE if self.operator == ConditionOperator.CONJUNCTION: - result = eval_0 & eval_1 + result = first_eval & second_eval if result: return ConditionState.TRUE - return result + + temporals = (first_eval | second_eval) & ( + ConditionState.BEFORE | ConditionState.EXPIRED + ) + if ConditionState.EXPIRED in temporals: + return ConditionState.EXPIRED + + return temporals elif self.operator == ConditionOperator.DISJUNCTION: - result = eval_0 | eval_1 + result = first_eval | second_eval + if result: return ConditionState.TRUE + + if ConditionState.BEFORE in result: + return ConditionState.BEFORE + return result return ConditionState.FALSE From 31307dbb4659978a1168bb1276bc31d3e52492d7 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 22:34:02 +0000 Subject: [PATCH 30/59] Clarify operations. --- smarts/sstudio/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 5fb107cf8f..444a8aceaf 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -833,17 +833,17 @@ class CompoundCondition(Condition): """This compounds multiple conditions. The following cases are notable - CONJUNCTION + CONJUNCTION (A AND B) If both conditions evaluate TRUE the result is exclusively TRUE. Else if either condition evaluates EXPIRED the result will be EXPIRED. Else if either condition evaluates BEFORE the result will be BEFORE. Else FALSE - DISJUNCTION + DISJUNCTION (A OR B) If either condition evaluates TRUE the result is exclusively TRUE. Else if either condition evaluates BEFORE then the result will be BEFORE. Else if either condition evaluates EXPIRED then the result will be EXPIRED. Else FALSE - IMPLICATION + IMPLICATION (A AND B or not A) If the first condition evaluates *not* TRUE the result is exclusively TRUE. Else if the first condition evaluates TRUE and the second condition evaluates TRUE the result is exclusively TRUE. Else FALSE From 79f171a85dbb79d82d47a5dbe376504dedef573e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 22:37:18 +0000 Subject: [PATCH 31/59] Update condition descriptions. --- smarts/sstudio/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 444a8aceaf..c42ba2fd38 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -650,15 +650,15 @@ def negate(self) -> "NegatedCondition": return NegatedCondition(self) def conjoin(self, other: "Condition") -> "CompoundCondition": - """AND's this condition with the other condition.""" + """Resolve conditions as A AND B.""" return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) def disjoin(self, other: "Condition") -> "CompoundCondition": - """OR's this condition with the other condition.""" + """Resolve conditions as A OR B.""" return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) def implicate(self, other: "Condition") -> "CompoundCondition": - """Current condition must be false or both conditions true to be true.""" + """Resolve conditions as A AND B OR NOT A.""" return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) def delay(self, seconds, persistant=False) -> "DelayCondition": From 4360917d6fcf7b5290d986ebfa2a9e1b9c821d49 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 5 May 2023 22:48:32 +0000 Subject: [PATCH 32/59] Fix logic for disjunction of expired condition state. --- smarts/sstudio/tests/test_conditions.py | 2 +- smarts/sstudio/types.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 5d7e2ccbdc..eb8f11d854 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -137,8 +137,8 @@ def test_compound_condition(): assert literal_before.disjoin(literal_false).evaluate() == ConditionState.BEFORE assert literal_expired.disjoin(literal_expired).evaluate() == ConditionState.EXPIRED - assert literal_expired.disjoin(literal_false).evaluate() == ConditionState.EXPIRED + assert literal_expired.disjoin(literal_false).evaluate() == ConditionState.FALSE assert literal_false.disjoin(literal_false).evaluate() == ConditionState.FALSE assert CompoundCondition( diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index c42ba2fd38..57fb77134f 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -841,7 +841,7 @@ class CompoundCondition(Condition): DISJUNCTION (A OR B) If either condition evaluates TRUE the result is exclusively TRUE. Else if either condition evaluates BEFORE then the result will be BEFORE. - Else if either condition evaluates EXPIRED then the result will be EXPIRED. + Else if both conditions evaluate EXPIRED then the result will be EXPIRED. Else FALSE IMPLICATION (A AND B or not A) If the first condition evaluates *not* TRUE the result is exclusively TRUE. @@ -892,7 +892,10 @@ def evaluate(self, *args, **kwargs): if ConditionState.BEFORE in result: return ConditionState.BEFORE - return result + if ConditionState.EXPIRED in first_eval & second_eval: + return ConditionState.EXPIRED + + return ConditionState.FALSE return ConditionState.FALSE From ee8cf6ddd3f124cfdecaa29d10abe9742e05283b Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 8 May 2023 10:13:20 -0400 Subject: [PATCH 33/59] Add test for dependee condition. --- smarts/sstudio/tests/test_conditions.py | 4 +++- smarts/sstudio/types.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index eb8f11d854..9b450f9dcb 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -222,7 +222,9 @@ def test_delay_condition(): def test_dependee_condition(): dependee_condition = DependeeActorCondition("leader") - pass + + assert dependee_condition.evaluate(active_actor_ids={"mr", "leader"}) + assert not dependee_condition.evaluate(active_actor_ids={"other", "vehicle"}) def test_literal_condition(): diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 57fb77134f..593038937a 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -719,11 +719,14 @@ class DependeeActorCondition(Condition): actor_id: str """The id of an actor in the simulation that needs to exist for this condition to be true.""" - def evaluate(self, *args, actor_ids, **kwargs): - if self.actor_id in actor_ids: + def evaluate(self, *args, active_actor_ids, **kwargs): + if self.actor_id in active_actor_ids: return ConditionState.TRUE return ConditionState.FALSE + def __post_init__(self): + assert isinstance(self.actor_id, str) + @dataclass(frozen=True) class NegatedCondition(Condition): From a0aaa4a4e1d21d295f178cd38f709feb782d25bf Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Tue, 9 May 2023 20:29:17 +0000 Subject: [PATCH 34/59] Update condition with information requirements. --- smarts/core/condition_state.py | 38 ++++++ smarts/core/id_actor_capture_manager.py | 4 +- smarts/core/trap_manager.py | 25 +++- smarts/sstudio/tests/test_conditions.py | 140 ++++++++++++--------- smarts/sstudio/types.py | 159 +++++++++++++++++++----- 5 files changed, 272 insertions(+), 94 deletions(-) create mode 100644 smarts/core/condition_state.py diff --git a/smarts/core/condition_state.py b/smarts/core/condition_state.py new file mode 100644 index 0000000000..e95cc2a093 --- /dev/null +++ b/smarts/core/condition_state.py @@ -0,0 +1,38 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +from enum import IntFlag + + +class ConditionState(IntFlag): + """Represents the state of a condition.""" + + FALSE = 0 + """This condition is false.""" + BEFORE = 1 + """This condition is false and never evaluated true before.""" + EXPIRED = 2 + """This condition is false and will never evaluate true.""" + TRUE = 4 + """This condition is true.""" + + def __bool__(self) -> bool: + return self.TRUE in self diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 08f02b079d..5820ceca7d 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -23,9 +23,10 @@ from typing import Dict, Optional, Set, Tuple from smarts.core.actor_capture_manager import ActorCaptureManager +from smarts.core.condition_state import ConditionState from smarts.core.plan import Mission from smarts.core.vehicle import Vehicle -from smarts.sstudio.types import ConditionState, IdEntryTactic +from smarts.sstudio.types import IdEntryTactic class IdActorCaptureManager(ActorCaptureManager): @@ -66,6 +67,7 @@ def step(self, sim): actor_ids=sim.vehicle_index.vehicle_ids, vehicle_state=vehicle.state if vehicle else None, mission_start_time=mission.start_time, + mission=mission, ) if condition_result == ConditionState.EXPIRED: print(condition_result) diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index 2f11b626be..105b55ccca 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -22,11 +22,12 @@ import random as rand from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Set, Tuple +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple from shapely.geometry import Polygon from smarts.core.actor_capture_manager import ActorCaptureManager +from smarts.core.condition_state import ConditionState from smarts.core.coordinates import Point as MapPoint from smarts.core.plan import Mission, Plan, Start, default_entry_tactic from smarts.core.utils.file import replace @@ -70,6 +71,28 @@ def includes(self, vehicle_id: str): return False return True + def evaluate( + self, + simulation, + vehicle_state: Optional[Any], + ) -> ConditionState: + """Considers the given vehicle to see if it is applicable. + + Args: + simulation (SMARTS): The simulation reference + vehicle_state (VehicleState): The current vehicle state. + + Returns: + ConditionState: The current state of the condition. + """ + entry_tactic: TrapEntryTactic = self.mission.entry_tactic + return entry_tactic.condition.evaluate( + simulation_time=simulation.elapsed_sim_time, + actor_ids=simulation.vehicle_index.vehicle_ids, + vehicle_state=vehicle_state, + mission_start_time=self.mission.start_time, + ) + class TrapManager(ActorCaptureManager): """Facilitates agent hijacking of actors""" diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 9b450f9dcb..86c294a726 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -28,13 +28,13 @@ Condition, ConditionOperator, ConditionState, - DelayCondition, DependeeActorCondition, LiteralCondition, NegatedCondition, OnRoadCondition, SubjectCondition, TimeWindowCondition, + TriggerCondition, VehicleSpeedCondition, VehicleTypeCondition, ) @@ -82,19 +82,19 @@ def test_condition(): condition.evaluate(actor_info=None) with pytest.raises(TypeError): - condition.negate() + condition.negation() with pytest.raises(TypeError): - condition.conjoin(literal_true) + condition.conjunction(literal_true) with pytest.raises(TypeError): - condition.disjoin(literal_true) + condition.disjunction(literal_true) with pytest.raises(TypeError): - condition.implicate(literal_true) + condition.implication(literal_true) with pytest.raises(TypeError): - condition.delay(10) + condition.trigger(10) def test_compound_condition(): @@ -107,65 +107,85 @@ def test_compound_condition(): first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.CONJUNCTION, - ) == literal_true.conjoin(literal_false) - assert literal_true.conjoin(literal_true).evaluate() == ConditionState.TRUE + ) == literal_true.conjunction(literal_false) + assert literal_true.conjunction(literal_true).evaluate() == ConditionState.TRUE - assert literal_expired.conjoin(literal_expired).evaluate() == ConditionState.EXPIRED - assert literal_expired.conjoin(literal_true).evaluate() == ConditionState.EXPIRED - assert literal_expired.conjoin(literal_before).evaluate() == ConditionState.EXPIRED - assert literal_expired.conjoin(literal_false).evaluate() == ConditionState.EXPIRED + assert ( + literal_expired.conjunction(literal_expired).evaluate() + == ConditionState.EXPIRED + ) + assert ( + literal_expired.conjunction(literal_true).evaluate() == ConditionState.EXPIRED + ) + assert ( + literal_expired.conjunction(literal_before).evaluate() == ConditionState.EXPIRED + ) + assert ( + literal_expired.conjunction(literal_false).evaluate() == ConditionState.EXPIRED + ) - assert literal_before.conjoin(literal_true).evaluate() == ConditionState.BEFORE - assert literal_before.conjoin(literal_before).evaluate() == ConditionState.BEFORE - assert literal_before.conjoin(literal_false).evaluate() == ConditionState.BEFORE + assert literal_before.conjunction(literal_true).evaluate() == ConditionState.BEFORE + assert ( + literal_before.conjunction(literal_before).evaluate() == ConditionState.BEFORE + ) + assert literal_before.conjunction(literal_false).evaluate() == ConditionState.BEFORE - assert literal_false.conjoin(literal_true).evaluate() == ConditionState.FALSE - assert literal_false.conjoin(literal_false).evaluate() == ConditionState.FALSE + assert literal_false.conjunction(literal_true).evaluate() == ConditionState.FALSE + assert literal_false.conjunction(literal_false).evaluate() == ConditionState.FALSE assert CompoundCondition( first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.DISJUNCTION, - ) == literal_true.disjoin(literal_false) - assert literal_true.disjoin(literal_true).evaluate() == ConditionState.TRUE - assert literal_true.disjoin(literal_false).evaluate() == ConditionState.TRUE - assert literal_true.disjoin(literal_before).evaluate() == ConditionState.TRUE - assert literal_true.disjoin(literal_expired).evaluate() == ConditionState.TRUE + ) == literal_true.disjunction(literal_false) + assert literal_true.disjunction(literal_true).evaluate() == ConditionState.TRUE + assert literal_true.disjunction(literal_false).evaluate() == ConditionState.TRUE + assert literal_true.disjunction(literal_before).evaluate() == ConditionState.TRUE + assert literal_true.disjunction(literal_expired).evaluate() == ConditionState.TRUE - assert literal_before.disjoin(literal_before).evaluate() == ConditionState.BEFORE - assert literal_before.disjoin(literal_expired).evaluate() == ConditionState.BEFORE - assert literal_before.disjoin(literal_false).evaluate() == ConditionState.BEFORE + assert ( + literal_before.disjunction(literal_before).evaluate() == ConditionState.BEFORE + ) + assert ( + literal_before.disjunction(literal_expired).evaluate() == ConditionState.BEFORE + ) + assert literal_before.disjunction(literal_false).evaluate() == ConditionState.BEFORE - assert literal_expired.disjoin(literal_expired).evaluate() == ConditionState.EXPIRED + assert ( + literal_expired.disjunction(literal_expired).evaluate() + == ConditionState.EXPIRED + ) - assert literal_expired.disjoin(literal_false).evaluate() == ConditionState.FALSE - assert literal_false.disjoin(literal_false).evaluate() == ConditionState.FALSE + assert literal_expired.disjunction(literal_false).evaluate() == ConditionState.FALSE + assert literal_false.disjunction(literal_false).evaluate() == ConditionState.FALSE assert CompoundCondition( first_condition=literal_true, second_condition=literal_false, operator=ConditionOperator.IMPLICATION, - ) == literal_true.implicate(literal_false) - assert literal_true.implicate(literal_true).evaluate() == ConditionState.TRUE + ) == literal_true.implication(literal_false) + assert literal_true.implication(literal_true).evaluate() == ConditionState.TRUE - assert literal_true.implicate(literal_expired).evaluate() == ConditionState.FALSE - assert literal_true.implicate(literal_before).evaluate() == ConditionState.FALSE - assert literal_true.implicate(literal_false).evaluate() == ConditionState.FALSE + assert literal_true.implication(literal_expired).evaluate() == ConditionState.FALSE + assert literal_true.implication(literal_before).evaluate() == ConditionState.FALSE + assert literal_true.implication(literal_false).evaluate() == ConditionState.FALSE - assert literal_expired.implicate(literal_true).evaluate() == ConditionState.TRUE - assert literal_expired.implicate(literal_expired).evaluate() == ConditionState.TRUE - assert literal_expired.implicate(literal_before).evaluate() == ConditionState.TRUE - assert literal_expired.implicate(literal_false).evaluate() == ConditionState.TRUE + assert literal_expired.implication(literal_true).evaluate() == ConditionState.TRUE + assert ( + literal_expired.implication(literal_expired).evaluate() == ConditionState.TRUE + ) + assert literal_expired.implication(literal_before).evaluate() == ConditionState.TRUE + assert literal_expired.implication(literal_false).evaluate() == ConditionState.TRUE - assert literal_before.implicate(literal_true).evaluate() == ConditionState.TRUE - assert literal_before.implicate(literal_expired).evaluate() == ConditionState.TRUE - assert literal_before.implicate(literal_before).evaluate() == ConditionState.TRUE - assert literal_before.implicate(literal_false).evaluate() == ConditionState.TRUE + assert literal_before.implication(literal_true).evaluate() == ConditionState.TRUE + assert literal_before.implication(literal_expired).evaluate() == ConditionState.TRUE + assert literal_before.implication(literal_before).evaluate() == ConditionState.TRUE + assert literal_before.implication(literal_false).evaluate() == ConditionState.TRUE - assert literal_false.implicate(literal_true).evaluate() == ConditionState.TRUE - assert literal_false.implicate(literal_expired).evaluate() == ConditionState.TRUE - assert literal_false.implicate(literal_false).evaluate() == ConditionState.TRUE - assert literal_false.implicate(literal_false).evaluate() == ConditionState.TRUE + assert literal_false.implication(literal_true).evaluate() == ConditionState.TRUE + assert literal_false.implication(literal_expired).evaluate() == ConditionState.TRUE + assert literal_false.implication(literal_false).evaluate() == ConditionState.TRUE + assert literal_false.implication(literal_false).evaluate() == ConditionState.TRUE def test_delay_condition(): @@ -173,11 +193,11 @@ def test_delay_condition(): long_delay = 10 first_time_window_true = 5 window_condition = TimeWindowCondition(4, 10) - delayed_condition = window_condition.delay(long_delay, persistant=False) + delayed_condition = window_condition.trigger(long_delay, persistant=False) - assert delayed_condition == DelayCondition( + assert delayed_condition == TriggerCondition( inner_condition=window_condition, - seconds=long_delay, + delay_seconds=long_delay, persistant=False, ) @@ -211,7 +231,7 @@ def test_delay_condition(): assert not delayed_condition.evaluate(simulation_time=time) # Test persistant true - delayed_condition = window_condition.delay(short_delay, persistant=True) + delayed_condition = window_condition.trigger(short_delay, persistant=True) time = first_time_window_true assert not delayed_condition.evaluate(simulation_time=time) time = first_time_window_true + short_delay @@ -223,8 +243,8 @@ def test_delay_condition(): def test_dependee_condition(): dependee_condition = DependeeActorCondition("leader") - assert dependee_condition.evaluate(active_actor_ids={"mr", "leader"}) - assert not dependee_condition.evaluate(active_actor_ids={"other", "vehicle"}) + assert dependee_condition.evaluate(actor_ids={"mr", "leader"}) + assert not dependee_condition.evaluate(actor_ids={"other", "vehicle"}) def test_literal_condition(): @@ -241,11 +261,11 @@ def test_negated_condition(): literal_true = LiteralCondition(ConditionState.TRUE) literal_false = LiteralCondition(ConditionState.FALSE) - assert literal_false.negate() == NegatedCondition(literal_false) - assert literal_true.negate() == NegatedCondition(literal_true) + assert literal_false.negation() == NegatedCondition(literal_false) + assert literal_true.negation() == NegatedCondition(literal_true) - assert literal_false.negate().evaluate() - assert not literal_true.negate().evaluate() + assert literal_false.negation().evaluate() + assert not literal_true.negation().evaluate() def test_on_road_condition(): @@ -276,19 +296,19 @@ def test_subject_condition(): subject_condition.evaluate(vehicle_state=None) with pytest.raises(TypeError): - subject_condition.negate() + subject_condition.negation() with pytest.raises(TypeError): - subject_condition.conjoin(literal_true) + subject_condition.conjunction(literal_true) with pytest.raises(TypeError): - subject_condition.disjoin(literal_true) + subject_condition.disjunction(literal_true) with pytest.raises(TypeError): - subject_condition.implicate(literal_true) + subject_condition.implication(literal_true) with pytest.raises(TypeError): - subject_condition.delay(10) + subject_condition.trigger(10) def test_vehicle_speed_condition(): diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 593038937a..b70c40922e 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -26,6 +26,7 @@ import warnings from dataclasses import dataclass, field, replace from enum import IntEnum, IntFlag +from functools import cached_property from typing import ( Any, Callable, @@ -34,6 +35,7 @@ List, Optional, Sequence, + Set, Tuple, Type, Union, @@ -54,6 +56,7 @@ from smarts.core import gen_id from smarts.core.colors import Colors +from smarts.core.condition_state import ConditionState from smarts.core.coordinates import RefLinePoint from smarts.core.default_map_builder import get_road_map from smarts.core.road_map import RoadMap @@ -600,22 +603,6 @@ class Traffic: """ -class ConditionState(IntFlag): - """Represents the state of a condition.""" - - FALSE = 0 - """This condition is false.""" - BEFORE = 1 - """This condition is false and never evaluated true before.""" - EXPIRED = 2 - """This condition is false and will never evaluate true.""" - TRUE = 4 - """This condition is true.""" - - def __bool__(self) -> bool: - return self.TRUE in self - - class ConditionOperator(IntEnum): """Represents logical operators between conditions.""" @@ -633,6 +620,35 @@ class ConditionOperator(IntEnum): # """True if its operand is false, otherwise false.""" +class ConditionRequires(IntFlag): + none = enum.auto() + + # MISSION CONSTANTS + agent_id = enum.auto() + mission = enum.auto() + + # SIMULATION STATE + simulation_time = enum.auto() + actor_ids = enum.auto() + actor_states = enum.auto() + simulation = enum.auto() + + # ACTOR STATE + current_actor_state = enum.auto() + current_actor_road_status = enum.auto() + + all_simulation_state = simulation_time | actor_ids | actor_states | simulation + all_current_actor_state = mission | current_actor_state | current_actor_road_status + + +@dataclass(frozen=True) +class ConditionEvaluationArgs: + actor_ids: Set[str] + mission_start_time: float + simulation_time: Union[float, int] + vehicle_state: Any + + @dataclass(frozen=True) class Condition: """This encompasses an expression to evaluate to a logical result.""" @@ -645,25 +661,48 @@ def evaluate(self, *args, **kwargs) -> ConditionState: """ raise NotImplementedError() - def negate(self) -> "NegatedCondition": + @property + def requires(self) -> ConditionRequires: + """Information that the condition requires to evaluate state. + + Returns: + ConditionRequires: The types of information this condition needs in order to evaluate. + """ + raise NotImplementedError() + + def negation(self) -> "NegatedCondition": """Negates this condition.""" return NegatedCondition(self) - def conjoin(self, other: "Condition") -> "CompoundCondition": + def conjunction(self, other: "Condition") -> "CompoundCondition": """Resolve conditions as A AND B.""" return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) - def disjoin(self, other: "Condition") -> "CompoundCondition": + def disjunction(self, other: "Condition") -> "CompoundCondition": """Resolve conditions as A OR B.""" return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) - def implicate(self, other: "Condition") -> "CompoundCondition": + def implication(self, other: "Condition") -> "CompoundCondition": """Resolve conditions as A AND B OR NOT A.""" return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) - def delay(self, seconds, persistant=False) -> "DelayCondition": + def trigger(self, seconds, persistant=False) -> "TriggerCondition": """Delays the current condition until the given number of simulation seconds have occured.""" - return DelayCondition(self, seconds=seconds, persistant=persistant) + return TriggerCondition(self, delay_seconds=seconds, persistant=persistant) + + def __and__(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A AND B""" + assert isinstance(other, Condition) + return self.conjunction(other) + + def __or__(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A OR B.""" + assert isinstance(other, Condition) + return self.disjunction(other) + + def __neg__(self) -> "NegatedCondition": + """Negates this condition""" + return self.negation() @dataclass(frozen=True) @@ -680,6 +719,13 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: """ raise NotImplementedError() + @property + def requires(self) -> ConditionRequires: + return ( + ConditionRequires.all_current_actor_state + | ConditionRequires.all_simulation_state + ) + _abstract_conditions = (Condition, SubjectCondition) @@ -694,6 +740,10 @@ class LiteralCondition(Condition): def evaluate(self, *args, **kwargs) -> ConditionState: return self.literal + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.none + @dataclass(frozen=True) class TimeWindowCondition(Condition): @@ -711,6 +761,10 @@ def evaluate(self, *args, simulation_time, **kwargs): return ConditionState.EXPIRED return ConditionState.BEFORE + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.simulation_time + @dataclass(frozen=True) class DependeeActorCondition(Condition): @@ -719,11 +773,15 @@ class DependeeActorCondition(Condition): actor_id: str """The id of an actor in the simulation that needs to exist for this condition to be true.""" - def evaluate(self, *args, active_actor_ids, **kwargs): - if self.actor_id in active_actor_ids: + def evaluate(self, *args, actor_ids, **kwargs): + if self.actor_id in actor_ids: return ConditionState.TRUE return ConditionState.FALSE + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.actor_ids + def __post_init__(self): assert isinstance(self.actor_id, str) @@ -738,6 +796,10 @@ class NegatedCondition(Condition): def evaluate(self, *args, **kwargs) -> ConditionState: return ~self.inner_condition.evaluate(*args, **kwargs) + @property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires + def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( @@ -746,8 +808,9 @@ def __post_init__(self): @dataclass(frozen=True) -class DelayCondition(Condition): - """This condition delays the inner condition by a number of seconds. +class TriggerCondition(Condition): + """This condition is a trigger that assumes FALSE and then turns true permanently on the inner section + becoming TRUE. The is an option to delay repsonse to the the inner condition by a number of seconds. This can be used to wait for some time after the inner condition has become true to be true. Note that the original condition may no longer be true by the time delay has expired. @@ -758,7 +821,7 @@ class DelayCondition(Condition): inner_condition: Condition """The inner condition to delay.""" - seconds: float + delay_seconds: float """The number of seconds to delay for.""" persistant: bool = False @@ -766,19 +829,28 @@ class DelayCondition(Condition): def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: key = "met_time" + result = ConditionState.FALSE if (met_time := getattr(self, key, None)) is not None: - if simulation_time >= met_time + self.seconds: + if simulation_time >= met_time + self.delay_seconds: result = ConditionState.TRUE if self.persistant: result &= self.inner_condition.evaluate( *args, simulation_time=simulation_time, **kwargs ) return result - elif self.inner_condition.evaluate( + elif result := self.inner_condition.evaluate( *args, simulation_time=simulation_time, **kwargs ): object.__setattr__(self, key, simulation_time) - return ConditionState.FALSE + + temporals = result & (ConditionState.BEFORE | ConditionState.EXPIRED) + if ConditionState.EXPIRED in temporals: + return ConditionState.EXPIRED + return temporals & ConditionState.BEFORE + + @property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: @@ -791,8 +863,16 @@ def __post_init__(self): class OnRoadCondition(SubjectCondition): """This condition is true if the subject is on road.""" - def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: - return ConditionState.TRUE if vehicle_state.on_road else ConditionState.FALSE + def evaluate(self, *args, current_actor_road_status, **kwargs) -> ConditionState: + return ( + ConditionState.TRUE + if current_actor_road_status.on_road + else ConditionState.FALSE + ) + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_road_status @dataclass(frozen=True) @@ -808,6 +888,10 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: else ConditionState.FALSE ) + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_state + @dataclass(frozen=True) class VehicleSpeedCondition(SubjectCondition): @@ -826,6 +910,10 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: else ConditionState.FALSE ) + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_state + @classmethod def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): return cls(low=abs_error, high=abs_error) @@ -902,6 +990,10 @@ def evaluate(self, *args, **kwargs): return ConditionState.FALSE + @cached_property + def requires(self) -> ConditionRequires: + return self.first_condition.requires | self.second_condition.requires + def __post_init__(self): for condition in (self.first_condition, self.second_condition): if condition.__class__ in _abstract_conditions: @@ -948,6 +1040,9 @@ class IdEntryTactic(EntryTactic): def __post_init__(self): assert isinstance(self.actor_id, str) assert isinstance(self.condition, (Condition)) + assert not ( + self.condition.requires & ConditionRequires.all_current_actor_state + ), f"Id entry tactic cannot use conditions that require any_vehicle_state." @dataclass(frozen=True) From 47fe0e02c745d93d482dcece75388e6ea0559f80 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 16:16:32 +0000 Subject: [PATCH 35/59] Update condition types and terminology. --- smarts/core/id_actor_capture_manager.py | 4 +- smarts/sstudio/tests/test_conditions.py | 88 ++++++++------ smarts/sstudio/types.py | 147 ++++++++++++++++++------ 3 files changed, 161 insertions(+), 78 deletions(-) diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 5820ceca7d..67e1bab137 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -63,7 +63,7 @@ def step(self, sim): assert isinstance(entry_tactic, IdEntryTactic) vehicle = sim.vehicle_index.vehicle_by_id(actor_id) condition_result = entry_tactic.condition.evaluate( - simulation_time=sim.elapsed_sim_time, + time=sim.elapsed_sim_time, actor_ids=sim.vehicle_index.vehicle_ids, vehicle_state=vehicle.state if vehicle else None, mission_start_time=mission.start_time, @@ -113,7 +113,7 @@ def reset(self, scenario, sim): continue vehicle = sim.vehicle_index.vehicle_by_id(entry_tactic.actor_id, None) condition_result = entry_tactic.condition.evaluate( - simulation_time=sim.elapsed_sim_time, + time=sim.elapsed_sim_time, actor_ids=sim.vehicle_index.vehicle_ids, vehicle_state=vehicle.state if vehicle else None, mission_start_time=mission.start_time, diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 86c294a726..d430e9dcc6 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -28,17 +28,23 @@ Condition, ConditionOperator, ConditionState, + ConditionTrigger, DependeeActorCondition, + ExpireTrigger, LiteralCondition, NegatedCondition, OnRoadCondition, SubjectCondition, TimeWindowCondition, - TriggerCondition, VehicleSpeedCondition, VehicleTypeCondition, ) +literal_true = LiteralCondition(ConditionState.TRUE) +literal_false = LiteralCondition(ConditionState.FALSE) +literal_before = LiteralCondition(ConditionState.BEFORE) +literal_expired = LiteralCondition(ConditionState.EXPIRED) + def test_condition_state(): assert bool(ConditionState.TRUE) @@ -74,8 +80,6 @@ def test_condition_state(): def test_condition(): - literal_true = LiteralCondition(ConditionState.TRUE) - condition = Condition() with pytest.raises(NotImplementedError): @@ -96,12 +100,11 @@ def test_condition(): with pytest.raises(TypeError): condition.trigger(10) + with pytest.raises(TypeError): + condition.expire(10) + def test_compound_condition(): - literal_true = LiteralCondition(ConditionState.TRUE) - literal_false = LiteralCondition(ConditionState.FALSE) - literal_before = LiteralCondition(ConditionState.BEFORE) - literal_expired = LiteralCondition(ConditionState.EXPIRED) assert CompoundCondition( first_condition=literal_true, @@ -188,14 +191,14 @@ def test_compound_condition(): assert literal_false.implication(literal_false).evaluate() == ConditionState.TRUE -def test_delay_condition(): +def test_condition_trigger(): short_delay = 4 long_delay = 10 first_time_window_true = 5 window_condition = TimeWindowCondition(4, 10) delayed_condition = window_condition.trigger(long_delay, persistant=False) - assert delayed_condition == TriggerCondition( + assert delayed_condition == ConditionTrigger( inner_condition=window_condition, delay_seconds=long_delay, persistant=False, @@ -204,40 +207,53 @@ def test_delay_condition(): # before time = 2 assert ( - not delayed_condition.evaluate(simulation_time=time) - ) and not window_condition.evaluate(simulation_time=time) + not delayed_condition.evaluate(time=time) + ) and not window_condition.evaluate(time=time) # first true time = first_time_window_true - assert ( - not delayed_condition.evaluate(simulation_time=time) - ) and window_condition.evaluate(simulation_time=time) + assert (not delayed_condition.evaluate(time=time)) and window_condition.evaluate( + time=time + ) # delay not expired time = first_time_window_true + long_delay - 1 assert ( - not delayed_condition.evaluate(simulation_time=time) - ) and not window_condition.evaluate(simulation_time=time) + not delayed_condition.evaluate(time=time) + ) and not window_condition.evaluate(time=time) # delay expired time = first_time_window_true + long_delay - assert delayed_condition.evaluate( - simulation_time=time - ) and not window_condition.evaluate(simulation_time=time) + assert delayed_condition.evaluate(time=time) and not window_condition.evaluate( + time=time + ) # delay expired time = first_time_window_true + long_delay + 1 - assert delayed_condition.evaluate( - simulation_time=time - ) and not window_condition.evaluate(simulation_time=time) + assert delayed_condition.evaluate(time=time) and not window_condition.evaluate( + time=time + ) # delay not expired time = first_time_window_true + long_delay - 1 - assert not delayed_condition.evaluate(simulation_time=time) + assert not delayed_condition.evaluate(time=time) # Test persistant true delayed_condition = window_condition.trigger(short_delay, persistant=True) time = first_time_window_true - assert not delayed_condition.evaluate(simulation_time=time) + assert not delayed_condition.evaluate(time=time) time = first_time_window_true + short_delay - assert delayed_condition.evaluate(simulation_time=time) + assert delayed_condition.evaluate(time=time) time = first_time_window_true + long_delay - assert not delayed_condition.evaluate(simulation_time=time) + assert not delayed_condition.evaluate(time=time) + + +def test_expiring_trigger(): + end_time = 10 + before = end_time - 1 + after = end_time + 1 + expire_trigger = literal_true.expire(end_time) + + assert expire_trigger == ExpireTrigger(literal_true, end_time) + + assert expire_trigger.evaluate(time=before) + assert not expire_trigger.evaluate(time=end_time) + assert not expire_trigger.evaluate(time=after) def test_dependee_condition(): @@ -253,14 +269,9 @@ def test_literal_condition(): assert literal_false.evaluate() == ConditionState.FALSE assert literal_true.evaluate() == ConditionState.TRUE - assert literal_true.evaluate() - assert not literal_false.evaluate() def test_negated_condition(): - literal_true = LiteralCondition(ConditionState.TRUE) - literal_false = LiteralCondition(ConditionState.FALSE) - assert literal_false.negation() == NegatedCondition(literal_false) assert literal_true.negation() == NegatedCondition(literal_true) @@ -280,16 +291,14 @@ def test_time_window_condition(): window_condition = TimeWindowCondition(start=start, end=end) - assert not window_condition.evaluate(simulation_time=start - 1) - assert window_condition.evaluate(simulation_time=start) - assert window_condition.evaluate(simulation_time=between) - assert not window_condition.evaluate(simulation_time=end) - assert not window_condition.evaluate(simulation_time=end + 1) + assert not window_condition.evaluate(time=start - 1) + assert window_condition.evaluate(time=start) + assert window_condition.evaluate(time=between) + assert not window_condition.evaluate(time=end) + assert not window_condition.evaluate(time=end + 1) def test_subject_condition(): - literal_true = LiteralCondition(ConditionState.TRUE) - subject_condition = SubjectCondition() with pytest.raises(NotImplementedError): @@ -310,6 +319,9 @@ def test_subject_condition(): with pytest.raises(TypeError): subject_condition.trigger(10) + with pytest.raises(TypeError): + subject_condition.expire(10) + def test_vehicle_speed_condition(): low = 30 diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index b70c40922e..0e34533ed3 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -628,7 +628,7 @@ class ConditionRequires(IntFlag): mission = enum.auto() # SIMULATION STATE - simulation_time = enum.auto() + time = enum.auto() actor_ids = enum.auto() actor_states = enum.auto() simulation = enum.auto() @@ -637,16 +637,20 @@ class ConditionRequires(IntFlag): current_actor_state = enum.auto() current_actor_road_status = enum.auto() - all_simulation_state = simulation_time | actor_ids | actor_states | simulation - all_current_actor_state = mission | current_actor_state | current_actor_road_status + any_simulation_state = time | actor_ids | actor_states | simulation + any_current_actor_state = mission | current_actor_state | current_actor_road_status @dataclass(frozen=True) class ConditionEvaluationArgs: - actor_ids: Set[str] - mission_start_time: float - simulation_time: Union[float, int] - vehicle_state: Any + agent_id: Optional[str] + mission: Optional[Any] + time: Optional[float] + actor_ids: Optional[Set[str]] + actor_states: Optional[List[Any]] + simulation: Any + current_actor_state: Optional[Any] + current_actor_road_status: Optional[Any] @dataclass(frozen=True) @@ -686,9 +690,41 @@ def implication(self, other: "Condition") -> "CompoundCondition": """Resolve conditions as A AND B OR NOT A.""" return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) - def trigger(self, seconds, persistant=False) -> "TriggerCondition": - """Delays the current condition until the given number of simulation seconds have occured.""" - return TriggerCondition(self, delay_seconds=seconds, persistant=persistant) + def trigger( + self, delay_seconds: float, persistant: bool = False + ) -> "ConditionTrigger": + """Converts the condition to a trigger which becomes permanently TRUE after the inner condition becomes TRUE. + + Args: + delay_seconds (float): Applies the trigger after the delay has passed since the inner condition first TRUE. Defaults to False. + persistant (bool, optional): Mixes the inner result with the trigger result using an AND operation. + + Returns: + ConditionTrigger: A resulting condition. + """ + return ConditionTrigger( + self, delay_seconds=delay_seconds, persistant=persistant + ) + + def expire(self, time, expired_state=ConditionState.EXPIRED) -> "ExpireTrigger": + """This trigger evaluates to the expired state value after the given simulation time. + + >>> trigger = LiteralCondition(ConditionState.TRUE).expire(20) + >>> trigger.evaluate(time=10) + ConditionState.TRUE + >>> trigger.evaluate(time=30) + ConditionState.FALSE + + Args: + time (float): The simulation time when this trigger changes. + expired_state (ConditionState, optional): The condition state to use when the simulation is after the given time. Defaults to ConditionState.EXPIRED. + + Returns: + ExpireTrigger: The resulting condition. + """ + return ExpireTrigger( + inner_condition=self, time=time, expired_state=expired_state + ) def __and__(self, other: "Condition") -> "CompoundCondition": """Resolve conditions as A AND B""" @@ -722,8 +758,8 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: @property def requires(self) -> ConditionRequires: return ( - ConditionRequires.all_current_actor_state - | ConditionRequires.all_simulation_state + ConditionRequires.any_current_actor_state + | ConditionRequires.any_simulation_state ) @@ -754,16 +790,16 @@ class TimeWindowCondition(Condition): end: float """The ending simulation time as of which this condition becomes expired.""" - def evaluate(self, *args, simulation_time, **kwargs): - if self.start <= simulation_time < self.end: + def evaluate(self, *args, time, **kwargs): + if self.start <= time < self.end or self.end == sys.maxsize: return ConditionState.TRUE - elif simulation_time > self.end: + elif time > self.end: return ConditionState.EXPIRED return ConditionState.BEFORE @property def requires(self) -> ConditionRequires: - return ConditionRequires.simulation_time + return ConditionRequires.time @dataclass(frozen=True) @@ -808,9 +844,39 @@ def __post_init__(self): @dataclass(frozen=True) -class TriggerCondition(Condition): - """This condition is a trigger that assumes FALSE and then turns true permanently on the inner section - becoming TRUE. The is an option to delay repsonse to the the inner condition by a number of seconds. +class ExpireTrigger(Condition): + """This condition allows for expiration after a given time.""" + + inner_condition: Condition + """The inner condition to delay.""" + + time: float + """The simulation time when this trigger becomes expired.""" + + expired_state: ConditionState = ConditionState.EXPIRED + """The state value this trigger should have when it expires.""" + + def evaluate(self, *args, time, **kwargs) -> ConditionState: + if time >= self.time: + return self.expired_state + return self.inner_condition.evaluate(*args, time=time, **kwargs) + + @property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires | ConditionRequires.time + + def __post_init__(self): + if self.inner_condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." + ) + + +@dataclass(frozen=True) +class ConditionTrigger(Condition): + """This condition is a trigger that assumes an untriggered constant state and then turns to the other state permanently + on the inner condition becoming TRUE. There is also an option to delay repsonse to the the inner condition by a number + of seconds. This will convey an EXPIRED value immediately because that state means the inner value will never be true. This can be used to wait for some time after the inner condition has become true to be true. Note that the original condition may no longer be true by the time delay has expired. @@ -824,29 +890,31 @@ class TriggerCondition(Condition): delay_seconds: float """The number of seconds to delay for.""" + untriggered_state: ConditionState = ConditionState.BEFORE + """The state before the inner trigger condition and delay is resolved.""" + + triggered_state: ConditionState = ConditionState.TRUE + """The state after the inner trigger condition and delay is resolved.""" + persistant: bool = False - """If the inner condition must still be true at the end of the delay to be true.""" + """If the inner condition state is used in conjuction with the triggered state. (inner_condition_state & triggered_state)""" - def evaluate(self, *args, simulation_time, **kwargs) -> ConditionState: + def evaluate(self, *args, time, **kwargs) -> ConditionState: key = "met_time" - result = ConditionState.FALSE - if (met_time := getattr(self, key, None)) is not None: - if simulation_time >= met_time + self.delay_seconds: - result = ConditionState.TRUE + result = self.untriggered_state + if self.delay_seconds <= 0 or (met_time := getattr(self, key, -1)) > -1: + if time >= met_time + self.delay_seconds: + result = self.triggered_state if self.persistant: - result &= self.inner_condition.evaluate( - *args, simulation_time=simulation_time, **kwargs - ) + result &= self.inner_condition.evaluate(*args, time=time, **kwargs) return result - elif result := self.inner_condition.evaluate( - *args, simulation_time=simulation_time, **kwargs - ): - object.__setattr__(self, key, simulation_time) + elif result := self.inner_condition.evaluate(*args, time=time, **kwargs): + object.__setattr__(self, key, time) - temporals = result & (ConditionState.BEFORE | ConditionState.EXPIRED) + temporals = result & (ConditionState.EXPIRED) if ConditionState.EXPIRED in temporals: return ConditionState.EXPIRED - return temporals & ConditionState.BEFORE + return self.untriggered_state @property def requires(self) -> ConditionRequires: @@ -855,7 +923,7 @@ def requires(self) -> ConditionRequires: def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: raise TypeError( - f"Abstract `{self.inner_condition.__class__.__name__}` cannot use delay operations." + f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." ) @@ -1026,6 +1094,12 @@ class TrapEntryTactic(EntryTactic): condition: Condition = LiteralCondition(ConditionState.TRUE) """A condition that is used to add additional exclusions.""" + def __post_init__(self): + assert isinstance(self.condition, (Condition)) + assert not ( + self.condition.requires & ConditionRequires.any_current_actor_state + ), f"Trap entry tactic cannot use conditions that require any_vehicle_state." + @dataclass(frozen=True) class IdEntryTactic(EntryTactic): @@ -1040,9 +1114,6 @@ class IdEntryTactic(EntryTactic): def __post_init__(self): assert isinstance(self.actor_id, str) assert isinstance(self.condition, (Condition)) - assert not ( - self.condition.requires & ConditionRequires.all_current_actor_state - ), f"Id entry tactic cannot use conditions that require any_vehicle_state." @dataclass(frozen=True) From de21f79c53fe806c4f2c80b6136e23ea387118ec Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 16:17:17 +0000 Subject: [PATCH 36/59] Implement conditions in trap manager. --- smarts/core/trap_manager.py | 126 +++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 51 deletions(-) diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index 105b55ccca..aee9de4b9d 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -74,7 +74,7 @@ def includes(self, vehicle_id: str): def evaluate( self, simulation, - vehicle_state: Optional[Any], + vehicle_state: Optional[Any] = None, ) -> ConditionState: """Considers the given vehicle to see if it is applicable. @@ -87,13 +87,22 @@ def evaluate( """ entry_tactic: TrapEntryTactic = self.mission.entry_tactic return entry_tactic.condition.evaluate( - simulation_time=simulation.elapsed_sim_time, + time=simulation.elapsed_sim_time, actor_ids=simulation.vehicle_index.vehicle_ids, vehicle_state=vehicle_state, mission_start_time=self.mission.start_time, ) +@dataclass(frozen=True) +class CaptureState: + ready_state: ConditionState + trap: Optional[Trap] + vehicle_id: Optional[str] = None + updated_mission: Optional[Mission] = None + default: bool = False + + class TrapManager(ActorCaptureManager): """Facilitates agent hijacking of actors""" @@ -182,15 +191,15 @@ def reset_traps(self, used_traps): self.remove_traps(used_traps) def step(self, sim): - """Run hijacking and update agent and actor states.""" + """Run vehicle hijacking and update agent and actor states.""" from smarts.core.smarts import SMARTS assert isinstance(sim, SMARTS) - captures_by_agent_id: Dict[str, List[Tuple[str, Trap, Mission]]] = defaultdict( - list + capture_by_agent_id: Dict[str, CaptureState] = defaultdict( + lambda: CaptureState(ConditionState.FALSE, None, default=True) ) - # Do an optimization to only check if there are pending agents. + # An optimization to short circuit if there are no pending agents. if not ( sim.agent_manager.pending_agent_ids | sim.agent_manager.pending_social_agent_ids @@ -206,24 +215,41 @@ def step(self, sim): v_id: sim.vehicle_index.vehicle_by_id(v_id) for v_id in social_vehicle_ids } - def largest_vehicle_plane_dimension(vehicle: Vehicle): - return max(*vehicle.chassis.dimensions.as_lwh[:2]) - vehicle_comp = [ - (v.position[:2], largest_vehicle_plane_dimension(v), v) + (v.position[:2], max(v.chassis.dimensions.as_lwh[:2]), v) for v in vehicles.values() ] - for agent_id in ( + pending_agent_ids = ( sim.agent_manager.pending_agent_ids | sim.agent_manager.pending_social_agent_ids - ): + ) + # Pending agents is currently used to avoid + for agent_id in pending_agent_ids: trap = self._traps.get(agent_id) if trap is None: continue + # Skip the capturing process if history traffic is used + if trap.mission.vehicle_spec is not None: + continue + if not trap.ready(sim.elapsed_sim_time): + capture_by_agent_id[agent_id] = CaptureState( + ConditionState.BEFORE, trap + ) + continue + + if trap.patience_expired(sim.elapsed_sim_time): + capture_by_agent_id[agent_id] = CaptureState( + ConditionState.EXPIRED, trap, updated_mission=trap.mission + ) + continue + + trap_condition = trap.evaluate(sim) + if not trap_condition: + capture_by_agent_id[agent_id] = CaptureState(trap_condition, trap) continue # Order vehicle ids by distance. @@ -233,61 +259,52 @@ def largest_vehicle_plane_dimension(vehicle: Vehicle): vehicles[v].position[:2], trap.mission.start.position[:2] ), ) - for v_id in sorted_vehicle_ids: - # Skip the capturing process if history traffic is used - if trap.mission.vehicle_spec is not None: - break - - if not trap.includes(v_id): + for vehicle_id in sorted_vehicle_ids: + if not trap.includes(vehicle_id): continue - vehicle: Vehicle = vehicles[v_id] + vehicle: Vehicle = vehicles[vehicle_id] point = vehicle.pose.point.as_shapely if not point.within(trap.geometry): continue - captures_by_agent_id[agent_id].append( - ( - v_id, - trap, - replace( - trap.mission, - start=Start(vehicle.position[:2], vehicle.pose.heading), - ), - ) + capture_by_agent_id[agent_id] = CaptureState( + ready_state=trap_condition, + trap=trap, + updated_mission=replace( + trap.mission, + start=Start(vehicle.position[:2], vehicle.pose.heading), + ), + vehicle_id=vehicle_id, ) - social_vehicle_ids.remove(v_id) + social_vehicle_ids.remove(vehicle_id) break + else: + capture_by_agent_id[agent_id] = CaptureState( + ready_state=trap_condition, + trap=trap, + ) used_traps = [] - for agent_id in ( - sim.agent_manager.pending_agent_ids - | sim.agent_manager.pending_social_agent_ids - ): - trap = self._traps.get(agent_id) + for agent_id in pending_agent_ids: + capture = capture_by_agent_id[agent_id] - if trap is None: + if capture.default: continue - if not trap.ready(sim.elapsed_sim_time): + if capture.trap is None: continue - captures = captures_by_agent_id[agent_id] + if not capture.trap.ready(sim.elapsed_sim_time): + continue vehicle: Optional[Vehicle] = None - if len(captures) > 0: - vehicle_id, trap, mission = rand.choice(captures) - vehicle = self._take_existing_vehicle( - sim, - vehicle_id, - agent_id, - mission, - social=agent_id in sim.agent_manager.pending_social_agent_ids, - ) - elif trap.patience_expired(sim.elapsed_sim_time): + if ConditionState.EXPIRED in capture.ready_state: # Make sure there is not a vehicle in the same location - mission = trap.mission + mission = capture.updated_mission + if mission is None: + continue if mission.vehicle_spec is None: nv_dims = Vehicle.agent_vehicle_dims(mission) new_veh_maxd = max(nv_dims.as_lwh[:2]) @@ -309,14 +326,21 @@ def largest_vehicle_plane_dimension(vehicle: Vehicle): trap.default_entry_speed, social=agent_id in sim.agent_manager.pending_social_agent_ids, ) + elif trap_condition and capture.vehicle_id is not None: + vehicle = self._take_existing_vehicle( + sim, + capture.vehicle_id, + agent_id, + capture.updated_mission, + social=agent_id in sim.agent_manager.pending_social_agent_ids, + ) else: continue if vehicle is None: continue - used_traps.append((agent_id, trap)) + used_traps.append((agent_id, capture.trap)) - if len(used_traps) > 0: - self.remove_traps(used_traps) + self.remove_traps(used_traps) @property def traps(self) -> Dict[str, Trap]: From ea43885c9b21b5a377546697b5da305d579e3d01 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 16:17:59 +0000 Subject: [PATCH 37/59] Bubble manager now inherits form ActorCaptureManager. --- smarts/core/bubble_manager.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/smarts/core/bubble_manager.py b/smarts/core/bubble_manager.py index 20ef36f641..fa085c9352 100644 --- a/smarts/core/bubble_manager.py +++ b/smarts/core/bubble_manager.py @@ -31,6 +31,7 @@ from shapely.affinity import rotate, translate from shapely.geometry import CAP_STYLE, JOIN_STYLE, Point, Polygon +from smarts.core.actor_capture_manager import ActorCaptureManager from smarts.core.data_model import SocialAgent from smarts.core.plan import ( EndlessGoal, @@ -206,7 +207,6 @@ def admissibility( ) ) - # pytype: disable=unsupported-operands all_hijacked_vehicle_ids = ( current_hijacked_vehicle_ids | vehicle_ids_by_bubble_state[BubbleState.InAirlock][self] @@ -216,7 +216,6 @@ def admissibility( current_shadowed_vehicle_ids | vehicle_ids_by_bubble_state[BubbleState.InBubble][self] ) - {vehicle_id} - # pytype: enable=unsupported-operands hijackable = len(all_hijacked_vehicle_ids) < ( self._limit.hijack_limit or maxsize @@ -421,7 +420,7 @@ def __hash__(self) -> int: return hash((self.vehicle_id, self.state, self.transition, self.bubble.id)) -class BubbleManager: +class BubbleManager(ActorCaptureManager): """Manages bubble interactions.""" def __init__(self, bubbles: Sequence[SSBubble], road_map: RoadMap): @@ -472,7 +471,7 @@ def is_active(bubble): @lru_cache(maxsize=2) def _vehicle_ids_divided_by_bubble_state( cursors: FrozenSet[Cursor], - ) -> Dict[Bubble, Set[Bubble]]: + ) -> Dict[BubbleState, Dict[Bubble, Set[Bubble]]]: vehicle_ids_grouped_by_cursor = defaultdict(lambda: defaultdict(set)) for cursor in cursors: vehicle_ids_grouped_by_cursor[cursor.state][cursor.bubble].add( From e0d1986785d1f059123924bf3af5319f65024f42 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 16:28:59 +0000 Subject: [PATCH 38/59] Fix docstring test. --- smarts/core/trap_manager.py | 16 ++++++++-------- smarts/sstudio/types.py | 5 +++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index aee9de4b9d..b28972d17e 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -95,7 +95,7 @@ def evaluate( @dataclass(frozen=True) -class CaptureState: +class _CaptureState: ready_state: ConditionState trap: Optional[Trap] vehicle_id: Optional[str] = None @@ -195,8 +195,8 @@ def step(self, sim): from smarts.core.smarts import SMARTS assert isinstance(sim, SMARTS) - capture_by_agent_id: Dict[str, CaptureState] = defaultdict( - lambda: CaptureState(ConditionState.FALSE, None, default=True) + capture_by_agent_id: Dict[str, _CaptureState] = defaultdict( + lambda: _CaptureState(ConditionState.FALSE, None, default=True) ) # An optimization to short circuit if there are no pending agents. @@ -236,20 +236,20 @@ def step(self, sim): continue if not trap.ready(sim.elapsed_sim_time): - capture_by_agent_id[agent_id] = CaptureState( + capture_by_agent_id[agent_id] = _CaptureState( ConditionState.BEFORE, trap ) continue if trap.patience_expired(sim.elapsed_sim_time): - capture_by_agent_id[agent_id] = CaptureState( + capture_by_agent_id[agent_id] = _CaptureState( ConditionState.EXPIRED, trap, updated_mission=trap.mission ) continue trap_condition = trap.evaluate(sim) if not trap_condition: - capture_by_agent_id[agent_id] = CaptureState(trap_condition, trap) + capture_by_agent_id[agent_id] = _CaptureState(trap_condition, trap) continue # Order vehicle ids by distance. @@ -269,7 +269,7 @@ def step(self, sim): if not point.within(trap.geometry): continue - capture_by_agent_id[agent_id] = CaptureState( + capture_by_agent_id[agent_id] = _CaptureState( ready_state=trap_condition, trap=trap, updated_mission=replace( @@ -281,7 +281,7 @@ def step(self, sim): social_vehicle_ids.remove(vehicle_id) break else: - capture_by_agent_id[agent_id] = CaptureState( + capture_by_agent_id[agent_id] = _CaptureState( ready_state=trap_condition, trap=trap, ) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 0e34533ed3..362f4a6c7b 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -621,6 +621,8 @@ class ConditionOperator(IntEnum): class ConditionRequires(IntFlag): + """This bitfield lays out the required information that a condition needs in order to evaluate.""" + none = enum.auto() # MISSION CONSTANTS @@ -643,6 +645,8 @@ class ConditionRequires(IntFlag): @dataclass(frozen=True) class ConditionEvaluationArgs: + """Standard arguments given to condition evaluations.""" + agent_id: Optional[str] mission: Optional[Any] time: Optional[float] @@ -984,6 +988,7 @@ def requires(self) -> ConditionRequires: @classmethod def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): + """Generates a speed condition which assumes that the subject is stationary.""" return cls(low=abs_error, high=abs_error) From 8d7e5bc6e9f7ca4e81b97f804bef5833e1188936 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 17:11:14 +0000 Subject: [PATCH 39/59] Exclude waymo init files in header gen. --- bin/gen_header.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/gen_header.sh b/bin/gen_header.sh index c2ac6032ee..7e859ccd3e 100644 --- a/bin/gen_header.sh +++ b/bin/gen_header.sh @@ -4,7 +4,7 @@ files="" if [[ $# -eq 0 ]]; then # No specific file specified, collate all files excluding auto generated files of *_pb2.py and *_pb2_grpc.py - python_files="$(find ./baselines/marl_benchmark ./cli ./envision ./smarts -name '*.py' ! -name '*_pb2.py' ! -name '*_pb2_grpc.py')" + python_files="$(find ./baselines/marl_benchmark ./cli ./envision ./smarts -name '*.py' ! -name '*_pb2.py' ! -name '*_pb2_grpc.py' ! -wholename '**/waymo/**/__init__.py')" js_files="$(find ./envision/web/src -name '*.js')" files="$python_files $js_files" else From 25840d7ff162de806845732a0af6029606cf68d8 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 17:33:02 +0000 Subject: [PATCH 40/59] Fix type issues. --- smarts/core/bubble_manager.py | 2 +- smarts/core/trap_manager.py | 20 ++++++++++---------- smarts/sstudio/types.py | 16 +++++++++++----- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/smarts/core/bubble_manager.py b/smarts/core/bubble_manager.py index fa085c9352..299e10466d 100644 --- a/smarts/core/bubble_manager.py +++ b/smarts/core/bubble_manager.py @@ -471,7 +471,7 @@ def is_active(bubble): @lru_cache(maxsize=2) def _vehicle_ids_divided_by_bubble_state( cursors: FrozenSet[Cursor], - ) -> Dict[BubbleState, Dict[Bubble, Set[Bubble]]]: + ) -> Dict[BubbleState, Dict[Bubble, Set[str]]]: vehicle_ids_grouped_by_cursor = defaultdict(lambda: defaultdict(set)) for cursor in cursors: vehicle_ids_grouped_by_cursor[cursor.state][cursor.bubble].add( diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index b28972d17e..e276bc1ac9 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -44,14 +44,14 @@ class Trap: """The trap area within which actors are considered for capture.""" mission: Mission """The mission that this trap should assign the captured actor.""" - exclusion_prefixes: Sequence[str] - """Prefixes of actors that should be ignored by this trap.""" activation_time: float """The amount of time left until this trap activates.""" patience: float """Patience to wait for better capture circumstances after which the trap expires.""" default_entry_speed: float """The default entry speed of a new vehicle should this trap expire.""" + entry_tactic: TrapEntryTactic + """The entry tactic that this trap was generated with.""" def ready(self, sim_time: float): """If the trap is ready to capture a vehicle.""" @@ -66,7 +66,7 @@ def patience_expired(self, sim_time: float): def includes(self, vehicle_id: str): """Returns if the given actor should be considered for capture.""" - for prefix in self.exclusion_prefixes: + for prefix in self.entry_tactic.exclusion_prefixes: if vehicle_id.startswith(prefix): return False return True @@ -85,8 +85,7 @@ def evaluate( Returns: ConditionState: The current state of the condition. """ - entry_tactic: TrapEntryTactic = self.mission.entry_tactic - return entry_tactic.condition.evaluate( + return self.entry_tactic.condition.evaluate( time=simulation.elapsed_sim_time, actor_ids=simulation.vehicle_index.vehicle_ids, vehicle_state=vehicle_state, @@ -357,11 +356,12 @@ def _mission2trap(self, road_map, mission: Mission, default_zone_dist: float = 6 if not (hasattr(mission, "start") and hasattr(mission, "goal")): raise ValueError(f"Value {mission} is not a mission!") - assert isinstance(mission.entry_tactic, TrapEntryTactic) + entry_tactic = mission.entry_tactic + assert isinstance(entry_tactic, TrapEntryTactic) - patience = mission.entry_tactic.wait_to_hijack_limit_s - zone = mission.entry_tactic.zone - default_entry_speed = mission.entry_tactic.default_entry_speed + patience = entry_tactic.wait_to_hijack_limit_s + zone = entry_tactic.zone + default_entry_speed = entry_tactic.default_entry_speed n_lane = None if default_entry_speed is None: @@ -404,8 +404,8 @@ def _mission2trap(self, road_map, mission: Mission, default_zone_dist: float = 6 activation_time=mission.start_time, patience=patience, mission=mission, - exclusion_prefixes=mission.entry_tactic.exclusion_prefixes, default_entry_speed=default_entry_speed, + entry_tactic=mission.entry_tactic, ) return trap diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 362f4a6c7b..9c27dfb65d 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -794,7 +794,7 @@ class TimeWindowCondition(Condition): end: float """The ending simulation time as of which this condition becomes expired.""" - def evaluate(self, *args, time, **kwargs): + def evaluate(self, *args, time, **kwargs) -> ConditionState: if self.start <= time < self.end or self.end == sys.maxsize: return ConditionState.TRUE elif time > self.end: @@ -813,7 +813,7 @@ class DependeeActorCondition(Condition): actor_id: str """The id of an actor in the simulation that needs to exist for this condition to be true.""" - def evaluate(self, *args, actor_ids, **kwargs): + def evaluate(self, *args, actor_ids, **kwargs) -> ConditionState: if self.actor_id in actor_ids: return ConditionState.TRUE return ConditionState.FALSE @@ -834,7 +834,9 @@ class NegatedCondition(Condition): """The inner condition to negate.""" def evaluate(self, *args, **kwargs) -> ConditionState: - return ~self.inner_condition.evaluate(*args, **kwargs) + result = ~self.inner_condition.evaluate(*args, **kwargs) + assert isinstance(result, ConditionState) + return result @property def requires(self) -> ConditionRequires: @@ -1022,7 +1024,7 @@ class CompoundCondition(Condition): operator: ConditionOperator """The operator used to combine these conditions.""" - def evaluate(self, *args, **kwargs): + def evaluate(self, *args, **kwargs) -> ConditionState: first_eval = self.first_condition.evaluate(*args, **kwargs) if self.operator == ConditionOperator.IMPLICATION and not first_eval: return ConditionState.TRUE @@ -1249,9 +1251,13 @@ class MapZone(Zone): n_lanes: int = 2 """The number of lanes from right to left that this zone covers.""" - def to_geometry(self, road_map: RoadMap) -> Polygon: + def to_geometry(self, road_map: Optional[RoadMap]) -> Polygon: """Generates a map zone over a stretch of the given lanes.""" + assert ( + road_map is not None + ), f"{self.__class__.__name__} requires a road map to resolve geometry." + def resolve_offset(offset, geometry_length, lane_length): if offset == "base": return 0 From 8bc0a57358e87a23b6be2002c97f6320d527569b Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 17:34:29 +0000 Subject: [PATCH 41/59] Update type condition tests. --- smarts/sstudio/tests/test_conditions.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index d430e9dcc6..45f90a9ebd 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -54,19 +54,27 @@ def test_condition_state(): assert ConditionState.TRUE assert (not ConditionState.TRUE) == False - assert not ~ConditionState.TRUE + assert ( + ConditionState.FALSE | ConditionState.BEFORE | ConditionState.EXPIRED + ) in ~ConditionState.TRUE assert ConditionState.FALSE == False assert not ConditionState.FALSE - assert ~ConditionState.FALSE + assert ( + ConditionState.TRUE | ConditionState.BEFORE | ConditionState.EXPIRED + ) in ~ConditionState.FALSE assert bool(ConditionState.EXPIRED) == False assert not ConditionState.EXPIRED - assert ~ConditionState.EXPIRED + assert ( + ConditionState.TRUE | ConditionState.BEFORE | ConditionState.FALSE + ) in ~ConditionState.EXPIRED assert bool(ConditionState.BEFORE) == False assert not ConditionState.BEFORE - assert ~ConditionState.BEFORE + assert ( + ConditionState.TRUE | ConditionState.FALSE | ConditionState.EXPIRED + ) in ~ConditionState.BEFORE assert ConditionState.TRUE | ConditionState.FALSE assert not ConditionState.TRUE & ConditionState.FALSE From 8f2de139d5c3d3c45c141f65e332b06d5d3a9cfd Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 19:05:56 +0000 Subject: [PATCH 42/59] Use flags to control condition evaluation parameters. --- smarts/core/actor_capture_manager.py | 91 ++++++++++++++++++++++++- smarts/core/id_actor_capture_manager.py | 25 ++++--- smarts/core/trap_manager.py | 31 +++------ smarts/sstudio/tests/test_conditions.py | 11 +-- smarts/sstudio/types.py | 72 +++++++++---------- 5 files changed, 151 insertions(+), 79 deletions(-) diff --git a/smarts/core/actor_capture_manager.py b/smarts/core/actor_capture_manager.py index be68fd429a..b894832988 100644 --- a/smarts/core/actor_capture_manager.py +++ b/smarts/core/actor_capture_manager.py @@ -23,10 +23,12 @@ import warnings from dataclasses import replace -from typing import Optional +from typing import Any, Dict, Optional +from smarts.core.actor import ActorState from smarts.core.plan import Plan from smarts.core.vehicle import Vehicle +from smarts.sstudio.types import ConditionRequires class ActorCaptureManager: @@ -114,3 +116,90 @@ def _take_existing_vehicle( sim.agent_manager.remove_pending_agent_ids({agent_id}) sim.create_vehicle_in_providers(vehicle, agent_id, True) return vehicle + + @classmethod + def _gen_all_condition_kwargs( + cls, + agent_id: str, + mission, + sim, + actor_state: ActorState, + condition_requires: ConditionRequires, + ): + return { + **cls._gen_mission_condition_kwargs(agent_id, mission, condition_requires), + **cls._gen_simulation_condition_kwargs(sim, condition_requires), + **cls._gen_actor_state_condition_args( + sim.road_map, actor_state, condition_requires + ), + } + + @staticmethod + def _gen_mission_condition_kwargs( + agent_id: str, mission, condition_requires: ConditionRequires + ) -> Dict[str, Any]: + out_kwargs = dict() + + if ( + ConditionRequires.any_mission_state & condition_requires + ) == ConditionRequires.none: + return out_kwargs + + if condition_requires.agent_id in condition_requires: + out_kwargs[ConditionRequires.agent_id.name] = agent_id + if ConditionRequires.mission in condition_requires: + out_kwargs[ConditionRequires.mission.name] = mission + return out_kwargs + + @staticmethod + def _gen_simulation_condition_kwargs( + sim, condition_requires: ConditionRequires + ) -> Dict[str, Any]: + out_kwargs = dict() + + if ( + ConditionRequires.any_simulation_state & condition_requires + ) == ConditionRequires.none: + return out_kwargs + + from smarts.core.smarts import SMARTS + + sim: SMARTS = sim + if ConditionRequires.time in condition_requires: + out_kwargs[ConditionRequires.time.name] = sim.elapsed_sim_time + if ConditionRequires.actor_ids in condition_requires: + out_kwargs[ConditionRequires.actor_ids.name] = sim.vehicle_index.vehicle_ids + if ConditionRequires.road_map in condition_requires: + out_kwargs[ConditionRequires.road_map.name] = sim.road_map + if ConditionRequires.actor_states in condition_requires: + out_kwargs[ConditionRequires.actor_states.name] = [ + v.state for v in sim.vehicle_index.vehicles + ] + if ConditionRequires.simulation in condition_requires: + out_kwargs[ConditionRequires.simulation.name] = sim + + return out_kwargs + + @staticmethod + def _gen_actor_state_condition_args( + road_map, + actor_state: Optional[ActorState], + condition_requires: ConditionRequires, + ) -> Dict[str, Any]: + out_kwargs = dict() + + if ( + ConditionRequires.any_current_actor_state & condition_requires + ) == ConditionRequires.none: + return out_kwargs + + from smarts.core.road_map import RoadMap + + assert isinstance(road_map, RoadMap) + + if ConditionRequires.current_actor_state in condition_requires: + out_kwargs[ConditionRequires.current_actor_state.name] = actor_state + if ConditionRequires.current_actor_road_status in condition_requires: + out_kwargs[ConditionRequires.current_actor_road_status.name] = None + + return out_kwargs diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index 67e1bab137..bf20e82471 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -62,13 +62,14 @@ def step(self, sim): entry_tactic = mission.entry_tactic assert isinstance(entry_tactic, IdEntryTactic) vehicle = sim.vehicle_index.vehicle_by_id(actor_id) - condition_result = entry_tactic.condition.evaluate( - time=sim.elapsed_sim_time, - actor_ids=sim.vehicle_index.vehicle_ids, - vehicle_state=vehicle.state if vehicle else None, - mission_start_time=mission.start_time, - mission=mission, + condition_kwargs = ActorCaptureManager._gen_all_condition_kwargs( + agent_id, + mission, + sim, + vehicle.state if vehicle is not None else None, + entry_tactic.condition.requires, ) + condition_result = entry_tactic.condition.evaluate(**condition_kwargs) if condition_result == ConditionState.EXPIRED: print(condition_result) self._log.warning( @@ -112,12 +113,14 @@ def reset(self, scenario, sim): if not isinstance(entry_tactic, IdEntryTactic): continue vehicle = sim.vehicle_index.vehicle_by_id(entry_tactic.actor_id, None) - condition_result = entry_tactic.condition.evaluate( - time=sim.elapsed_sim_time, - actor_ids=sim.vehicle_index.vehicle_ids, - vehicle_state=vehicle.state if vehicle else None, - mission_start_time=mission.start_time, + condition_kwargs = ActorCaptureManager._gen_all_condition_kwargs( + agent_id, + mission, + sim, + vehicle.state if vehicle is not None else None, + entry_tactic.condition.requires, ) + condition_result = entry_tactic.condition.evaluate(**condition_kwargs) if condition_result == ConditionState.EXPIRED: self._log.warning( f"Actor aquisition skipped for `{agent_id}` scheduled to start with" diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index e276bc1ac9..823995df60 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -71,27 +71,6 @@ def includes(self, vehicle_id: str): return False return True - def evaluate( - self, - simulation, - vehicle_state: Optional[Any] = None, - ) -> ConditionState: - """Considers the given vehicle to see if it is applicable. - - Args: - simulation (SMARTS): The simulation reference - vehicle_state (VehicleState): The current vehicle state. - - Returns: - ConditionState: The current state of the condition. - """ - return self.entry_tactic.condition.evaluate( - time=simulation.elapsed_sim_time, - actor_ids=simulation.vehicle_index.vehicle_ids, - vehicle_state=vehicle_state, - mission_start_time=self.mission.start_time, - ) - @dataclass(frozen=True) class _CaptureState: @@ -246,7 +225,15 @@ def step(self, sim): ) continue - trap_condition = trap.evaluate(sim) + sim_eval_kwargs = ActorCaptureManager._gen_simulation_condition_kwargs( + sim, trap.entry_tactic.condition.requires + ) + mission_eval_kwargs = ActorCaptureManager._gen_mission_condition_kwargs( + agent_id, trap.mission, trap.entry_tactic.condition.requires + ) + trap_condition = trap.entry_tactic.condition.evaluate( + **sim_eval_kwargs, **mission_eval_kwargs + ) if not trap_condition: capture_by_agent_id[agent_id] = _CaptureState(trap_condition, trap) continue diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 45f90a9ebd..63cca9582a 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -27,6 +27,7 @@ CompoundCondition, Condition, ConditionOperator, + ConditionRequires, ConditionState, ConditionTrigger, DependeeActorCondition, @@ -346,9 +347,9 @@ def test_vehicle_speed_condition(): fast_vehicle_state = Mock() fast_vehicle_state.speed = high + 50 - assert not vehicle_speed_condition.evaluate(vehicle_state=slow_vehicle_state) - assert vehicle_speed_condition.evaluate(vehicle_state=between_vehicle_state) - assert not vehicle_speed_condition.evaluate(vehicle_state=fast_vehicle_state) + assert not vehicle_speed_condition.evaluate(current_actor_state=slow_vehicle_state) + assert vehicle_speed_condition.evaluate(current_actor_state=between_vehicle_state) + assert not vehicle_speed_condition.evaluate(current_actor_state=fast_vehicle_state) def test_vehicle_type_condition(): @@ -360,5 +361,5 @@ def test_vehicle_type_condition(): truck_vehicle_state = Mock() truck_vehicle_state.vehicle_config_type = "truck" - assert vehicle_type_condition.evaluate(vehicle_state=passenger_vehicle_state) - assert not vehicle_type_condition.evaluate(vehicle_state=truck_vehicle_state) + assert vehicle_type_condition.evaluate(current_actor_state=passenger_vehicle_state) + assert not vehicle_type_condition.evaluate(current_actor_state=truck_vehicle_state) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 9c27dfb65d..e2a5bdaf98 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -623,7 +623,7 @@ class ConditionOperator(IntEnum): class ConditionRequires(IntFlag): """This bitfield lays out the required information that a condition needs in order to evaluate.""" - none = enum.auto() + none = 0 # MISSION CONSTANTS agent_id = enum.auto() @@ -633,6 +633,7 @@ class ConditionRequires(IntFlag): time = enum.auto() actor_ids = enum.auto() actor_states = enum.auto() + road_map = enum.auto() simulation = enum.auto() # ACTOR STATE @@ -641,27 +642,14 @@ class ConditionRequires(IntFlag): any_simulation_state = time | actor_ids | actor_states | simulation any_current_actor_state = mission | current_actor_state | current_actor_road_status - - -@dataclass(frozen=True) -class ConditionEvaluationArgs: - """Standard arguments given to condition evaluations.""" - - agent_id: Optional[str] - mission: Optional[Any] - time: Optional[float] - actor_ids: Optional[Set[str]] - actor_states: Optional[List[Any]] - simulation: Any - current_actor_state: Optional[Any] - current_actor_road_status: Optional[Any] + any_mission_state = agent_id | mission @dataclass(frozen=True) class Condition: """This encompasses an expression to evaluate to a logical result.""" - def evaluate(self, *args, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: """Used to evaluate if a condition is met. Returns: @@ -749,7 +737,7 @@ def __neg__(self) -> "NegatedCondition": class SubjectCondition(Condition): """This condition assumes that there is a subject involved.""" - def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: """Used to evaluate if a condition is met. Args: @@ -761,10 +749,7 @@ def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: @property def requires(self) -> ConditionRequires: - return ( - ConditionRequires.any_current_actor_state - | ConditionRequires.any_simulation_state - ) + return ConditionRequires.current_actor_state _abstract_conditions = (Condition, SubjectCondition) @@ -777,7 +762,7 @@ class LiteralCondition(Condition): literal: ConditionState """The literal value of this condition.""" - def evaluate(self, *args, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: return self.literal @property @@ -794,7 +779,8 @@ class TimeWindowCondition(Condition): end: float """The ending simulation time as of which this condition becomes expired.""" - def evaluate(self, *args, time, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] if self.start <= time < self.end or self.end == sys.maxsize: return ConditionState.TRUE elif time > self.end: @@ -813,7 +799,8 @@ class DependeeActorCondition(Condition): actor_id: str """The id of an actor in the simulation that needs to exist for this condition to be true.""" - def evaluate(self, *args, actor_ids, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + actor_ids = kwargs[self.requires.name] if self.actor_id in actor_ids: return ConditionState.TRUE return ConditionState.FALSE @@ -833,8 +820,8 @@ class NegatedCondition(Condition): inner_condition: Condition """The inner condition to negate.""" - def evaluate(self, *args, **kwargs) -> ConditionState: - result = ~self.inner_condition.evaluate(*args, **kwargs) + def evaluate(self, **kwargs) -> ConditionState: + result = ~self.inner_condition.evaluate(**kwargs) assert isinstance(result, ConditionState) return result @@ -862,12 +849,13 @@ class ExpireTrigger(Condition): expired_state: ConditionState = ConditionState.EXPIRED """The state value this trigger should have when it expires.""" - def evaluate(self, *args, time, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] if time >= self.time: return self.expired_state - return self.inner_condition.evaluate(*args, time=time, **kwargs) + return self.inner_condition.evaluate(**kwargs) - @property + @cached_property def requires(self) -> ConditionRequires: return self.inner_condition.requires | ConditionRequires.time @@ -905,16 +893,17 @@ class ConditionTrigger(Condition): persistant: bool = False """If the inner condition state is used in conjuction with the triggered state. (inner_condition_state & triggered_state)""" - def evaluate(self, *args, time, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] key = "met_time" result = self.untriggered_state if self.delay_seconds <= 0 or (met_time := getattr(self, key, -1)) > -1: if time >= met_time + self.delay_seconds: result = self.triggered_state if self.persistant: - result &= self.inner_condition.evaluate(*args, time=time, **kwargs) + result &= self.inner_condition.evaluate(**kwargs) return result - elif result := self.inner_condition.evaluate(*args, time=time, **kwargs): + elif result := self.inner_condition.evaluate(**kwargs): object.__setattr__(self, key, time) temporals = result & (ConditionState.EXPIRED) @@ -924,7 +913,7 @@ def evaluate(self, *args, time, **kwargs) -> ConditionState: @property def requires(self) -> ConditionRequires: - return self.inner_condition.requires + return self.inner_condition.requires | ConditionRequires.time def __post_init__(self): if self.inner_condition.__class__ in _abstract_conditions: @@ -937,7 +926,8 @@ def __post_init__(self): class OnRoadCondition(SubjectCondition): """This condition is true if the subject is on road.""" - def evaluate(self, *args, current_actor_road_status, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + current_actor_road_status = kwargs[self.requires.name] return ( ConditionState.TRUE if current_actor_road_status.on_road @@ -955,10 +945,11 @@ class VehicleTypeCondition(SubjectCondition): vehicle_type: str - def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + current_actor_state = kwargs[self.requires.name] return ( ConditionState.TRUE - if vehicle_state.vehicle_config_type == self.vehicle_type + if current_actor_state.vehicle_config_type == self.vehicle_type else ConditionState.FALSE ) @@ -977,7 +968,8 @@ class VehicleSpeedCondition(SubjectCondition): high: float """The highest speed allowed.""" - def evaluate(self, *args, vehicle_state, **kwargs) -> ConditionState: + def evaluate(self, **kwargs) -> ConditionState: + vehicle_state = kwargs[self.requires.name] return ( ConditionState.TRUE if self.low <= vehicle_state.speed <= self.high @@ -1024,12 +1016,12 @@ class CompoundCondition(Condition): operator: ConditionOperator """The operator used to combine these conditions.""" - def evaluate(self, *args, **kwargs) -> ConditionState: - first_eval = self.first_condition.evaluate(*args, **kwargs) + def evaluate(self, **kwargs) -> ConditionState: + first_eval = self.first_condition.evaluate(**kwargs) if self.operator == ConditionOperator.IMPLICATION and not first_eval: return ConditionState.TRUE - second_eval = self.second_condition.evaluate(*args, **kwargs) + second_eval = self.second_condition.evaluate(**kwargs) if ( self.operator == ConditionOperator.IMPLICATION and first_eval From e6e3c46dfb5409632b5559ac7374bf77d6f2d52c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 20:06:22 +0000 Subject: [PATCH 43/59] Fix condition trigger. --- smarts/sstudio/types.py | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index e2a5bdaf98..e1cedbaa34 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -685,7 +685,28 @@ def implication(self, other: "Condition") -> "CompoundCondition": def trigger( self, delay_seconds: float, persistant: bool = False ) -> "ConditionTrigger": - """Converts the condition to a trigger which becomes permanently TRUE after the inner condition becomes TRUE. + """Converts the condition to a trigger which becomes permanently TRUE after the first time the inner condition becomes TRUE. + + >>> trigger = TimeWindowCondition(2, 5).trigger(delay_seconds=0) + >>> trigger.evaluate(time=1) + + >>> trigger.evaluate(time=4) + + >>> trigger.evaluate(time=90) + + + >>> start_time = 5 + >>> between_time = 10 + >>> delay_seconds = 20 + >>> trigger = LiteralCondition(ConditionState.TRUE).trigger(delay_seconds=delay_seconds) + >>> trigger.evaluate(time=start_time) + + >>> trigger.evaluate(time=between_time) + + >>> trigger.evaluate(time=start_time + delay_seconds) + + >>> trigger.evaluate(time=between_time) + Args: delay_seconds (float): Applies the trigger after the delay has passed since the inner condition first TRUE. Defaults to False. @@ -703,9 +724,9 @@ def expire(self, time, expired_state=ConditionState.EXPIRED) -> "ExpireTrigger": >>> trigger = LiteralCondition(ConditionState.TRUE).expire(20) >>> trigger.evaluate(time=10) - ConditionState.TRUE + >>> trigger.evaluate(time=30) - ConditionState.FALSE + Args: time (float): The simulation time when this trigger changes. @@ -870,9 +891,9 @@ def __post_init__(self): class ConditionTrigger(Condition): """This condition is a trigger that assumes an untriggered constant state and then turns to the other state permanently on the inner condition becoming TRUE. There is also an option to delay repsonse to the the inner condition by a number - of seconds. This will convey an EXPIRED value immediately because that state means the inner value will never be true. + of seconds. This will convey an EXPIRED value immediately because that state means the inner value will never be TRUE. - This can be used to wait for some time after the inner condition has become true to be true. + This can be used to wait for some time after the inner condition has become TRUE to trigger. Note that the original condition may no longer be true by the time delay has expired. This will never resolve TRUE on the first evaluate. @@ -897,14 +918,17 @@ def evaluate(self, **kwargs) -> ConditionState: time = kwargs[ConditionRequires.time.name] key = "met_time" result = self.untriggered_state - if self.delay_seconds <= 0 or (met_time := getattr(self, key, -1)) > -1: + met_time = getattr(self, key, -1) + if met_time == -1 and self.inner_condition.evaluate(**kwargs): + object.__setattr__(self, key, time) + if met_time != -1 or ( + self.delay_seconds == 0 and self.inner_condition.evaluate(**kwargs) + ): if time >= met_time + self.delay_seconds: result = self.triggered_state if self.persistant: result &= self.inner_condition.evaluate(**kwargs) return result - elif result := self.inner_condition.evaluate(**kwargs): - object.__setattr__(self, key, time) temporals = result & (ConditionState.EXPIRED) if ConditionState.EXPIRED in temporals: From f0e896b8ff623ef28a2ca1e3a055e6440e77e876 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 21:16:48 +0000 Subject: [PATCH 44/59] Address type issues. --- smarts/core/actor_capture_manager.py | 12 +++++++++- smarts/core/bubble_manager.py | 12 ++++------ smarts/core/road_map.py | 2 +- smarts/sstudio/tests/test_conditions.py | 31 +++++++++++++++++++++---- smarts/sstudio/types.py | 9 +++++-- 5 files changed, 51 insertions(+), 15 deletions(-) diff --git a/smarts/core/actor_capture_manager.py b/smarts/core/actor_capture_manager.py index b894832988..9982d7e3da 100644 --- a/smarts/core/actor_capture_manager.py +++ b/smarts/core/actor_capture_manager.py @@ -22,6 +22,7 @@ import warnings +from collections import namedtuple from dataclasses import replace from typing import Any, Dict, Optional @@ -200,6 +201,15 @@ def _gen_actor_state_condition_args( if ConditionRequires.current_actor_state in condition_requires: out_kwargs[ConditionRequires.current_actor_state.name] = actor_state if ConditionRequires.current_actor_road_status in condition_requires: - out_kwargs[ConditionRequires.current_actor_road_status.name] = None + current_actor_road_status = namedtuple( + "actor_road_status", ["road", "off_road"], defaults=[None, False] + ) + if hasattr(actor_state, "pose"): + road = road_map.road_with_point(actor_state.pose.point) + current_actor_road_status.road = road + current_actor_road_status.off_road = not road + out_kwargs[ + ConditionRequires.current_actor_road_status.name + ] = current_actor_road_status return out_kwargs diff --git a/smarts/core/bubble_manager.py b/smarts/core/bubble_manager.py index 299e10466d..cf21e30976 100644 --- a/smarts/core/bubble_manager.py +++ b/smarts/core/bubble_manager.py @@ -295,7 +295,7 @@ def __hash__(self) -> int: return hash(self.id) -@dataclass(frozen=True, eq=True) +@dataclass(frozen=True) class Cursor: """Tracks an actor through an airlock or a bubble.""" @@ -306,9 +306,8 @@ class Cursor: transition: Optional[BubbleTransition] = None bubble: Optional[Bubble] = None - @classmethod + @staticmethod def for_removed( - cls, vehicle_id: str, bubble: Bubble, index: VehicleIndex, @@ -332,16 +331,15 @@ def for_removed( transition = None if was_in_this_bubble and (is_shadowed or is_hijacked): transition = BubbleTransition.AirlockExited - return cls( + return Cursor( vehicle_id=vehicle_id, transition=transition, state=BubbleState.WasInBubble, bubble=bubble, ) - @classmethod + @staticmethod def from_pos( - cls, pos: Point, vehicle_id: str, bubble: Bubble, @@ -409,7 +407,7 @@ def from_pos( elif in_airlock_zone: state = BubbleState.InAirlock - return cls( + return Cursor( vehicle_id=vehicle_id, transition=transition, state=state, bubble=bubble ) diff --git a/smarts/core/road_map.py b/smarts/core/road_map.py index b2026829e2..86fe07ce74 100644 --- a/smarts/core/road_map.py +++ b/smarts/core/road_map.py @@ -135,7 +135,7 @@ def nearest_lane( nearest_lanes = self.nearest_lanes(point, radius, include_junctions) return nearest_lanes[0][0] if nearest_lanes else None - def road_with_point(self, point: Point) -> RoadMap.Road: + def road_with_point(self, point: Point) -> Optional[RoadMap.Road]: """Find the road that contains the given point.""" raise NotImplementedError() diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 63cca9582a..08b276af08 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -34,7 +34,7 @@ ExpireTrigger, LiteralCondition, NegatedCondition, - OnRoadCondition, + OffRoadCondition, SubjectCondition, TimeWindowCondition, VehicleSpeedCondition, @@ -288,9 +288,32 @@ def test_negated_condition(): assert not literal_true.negation().evaluate() -def test_on_road_condition(): - on_road_condition = OnRoadCondition() - pass +def test_off_road_condition(): + off_road_condition = OffRoadCondition() + + current_actor_road_status = Mock() + current_actor_road_status.off_road = False + current_actor_road_status.road = "c-ew" + assert ( + off_road_condition.evaluate(current_actor_road_status=current_actor_road_status) + == ConditionState.FALSE + ) + + current_actor_road_status = Mock() + current_actor_road_status.off_road = True + current_actor_road_status.road = None + assert ( + off_road_condition.evaluate(current_actor_road_status=current_actor_road_status) + == ConditionState.TRUE + ) + + current_actor_road_status = Mock() + current_actor_road_status.off_road = False + current_actor_road_status.road = None + assert ( + off_road_condition.evaluate(current_actor_road_status=current_actor_road_status) + == ConditionState.BEFORE + ) def test_time_window_condition(): diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index e1cedbaa34..0ebe0e7fcb 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -947,14 +947,19 @@ def __post_init__(self): @dataclass(frozen=True) -class OnRoadCondition(SubjectCondition): +class OffRoadCondition(SubjectCondition): """This condition is true if the subject is on road.""" def evaluate(self, **kwargs) -> ConditionState: current_actor_road_status = kwargs[self.requires.name] + if ( + current_actor_road_status.road is None + and not current_actor_road_status.off_road + ): + return ConditionState.BEFORE return ( ConditionState.TRUE - if current_actor_road_status.on_road + if current_actor_road_status.off_road else ConditionState.FALSE ) From 2991d669c0961178959f8130fa8afc905a0bc665 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 21:42:08 +0000 Subject: [PATCH 45/59] Update changelog. --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 296bd5f072..53c60ef247 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Copy and pasting the git commit messages is __NOT__ enough. ### Added - `visdom` can now be configured through the engine.ini configuration file `visdom:enabled`, `visdom:hostname`, and `visdom:port` (environment variables `SMARTS_VISDOM_ENABLED`, `SMARTS_VISDOM_HOSTNAME`, `SMARTS_VISDOM_PORT`.) - Added an install extra that installs the requirements for all optional modules. Use `pip install .[all]`. +- Added `Condition`, `ConditionRequires`, `ConditionState` and various condition implementations to enable logical operations in scenarios. ### Changed - Changed waypoints in sumo maps to use more incoming lanes into junctions. - Increased the cutoff radius for filtering out waypoints that are too far away in junctions in sumo maps. @@ -20,6 +21,7 @@ Copy and pasting the git commit messages is __NOT__ enough. - `SumoTrafficSimulator` now uses the last vehicle subscription update to back `route_for_vehicle()`. This means that the routes of vehicles can still be determined even if `SumoTrafficSimulation` disconnects. - Reward function in platoon RL example retrieves actor-of-interest from marked neighborhood vehicles. - `dist_to_destination` metric cost function computes the route distance and end point for vehicle of interest contained in social agents, SMARTS traffic provider, SUMO traffic provider, and traffic history provider. +- `sstudio` generated scenario vehicle traffic IDs are now shortened. ### Deprecated - `visdom` is set to be removed from the SMARTS object parameters. ### Fixed @@ -54,7 +56,6 @@ Copy and pasting the git commit messages is __NOT__ enough. - Documented the challenge objective, desired inference code structure, and use of baseline example, for Driving SMARTS 2023.1 (i.e., basic motion planning) and 2023.2 (i.e, turns) benchmarks. - Added an env wrapper for constraining the relative target pose action range. - Added a specialised metric formula module for Driving SMARTS 2023.1 and 2023.2 benchmark. -- Added representation interface `Condition` and `ConditionState` for conditions to scenario studio. ### Changed - The trap manager, `TrapManager`, is now a subclass of `ActorCaptureManager`. - Considering lane-change time ranges between 3s and 6s, assuming a speed of 13.89m/s, the via sensor lane acquisition range was increased from 40m to 80m, for better driving ability. @@ -84,7 +85,6 @@ Copy and pasting the git commit messages is __NOT__ enough. - Driving SMARTS 2023.3 benchmark and the metrics module now uses `actor_of_interest_re_filter` from scenario metadata to identify the lead vehicle. - Included `RelativeTargetPose` action space to the set of allowed action spaces in `platoon-v0` env. - `Collision.collidee_id` now gives the vehicle id rather than the name of the owner of the vehicle (usually the agent id.) `Collision.collidee_owner_id` now provides the id of the controlling `agent` (or other controlling entity in the future.) This is because 1) `collidee_id` should refer to the body and 2) in most cases the owner name would be `None`. -- `sstudio` generated scenario vehicle traffic IDs are now shortened. - Entry tactics now use conditions to determine when they should capture an actor. ### Deprecated ### Fixed From 469c6204093a214aa7d6d9fecf23d972f3538ad3 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 11 May 2023 22:18:46 +0000 Subject: [PATCH 46/59] Fix scenarios. --- .../scenario.py | 2 +- .../scenario.py | 58 +++++++++---------- .../sumo/merge/3lane_agents_1/scenario.py | 6 +- .../sumo/merge/3lane_agents_2/scenario.py | 4 +- .../3lane_cruise_agents_1/scenario.py | 2 +- .../3lane_cruise_agents_3/scenario.py | 6 +- 6 files changed, 36 insertions(+), 42 deletions(-) diff --git a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py index 3dfcacb901..688f19bc43 100644 --- a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py @@ -95,7 +95,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=4, wait_to_hijack_limit_s=0.1 + start_time=4, wait_to_hijack_limit_s=0 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py index 6a12f07434..420537e82d 100644 --- a/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_2lane_left_turn_t_agents_1/scenario.py @@ -37,61 +37,59 @@ name="car", ) -# flow_name = (start_lane, end_lane) -route_opt = [ - (0, 0), - (1, 1), - (2, 2), +horizontal_routes = [ + ("E4", 0, "E1", 0), + ("E4", 1, "E1", 1), + ("-E1", 0, "-E4", 0), + ("-E1", 1, "-E4", 1), ] -# Traffic combinations = 3C2 + 3C3 = 3 + 1 = 4 -# Repeated traffic combinations = 4 * 100 = 400 -min_flows = 2 -max_flows = 3 -route_comb = [ - com - for elems in range(min_flows, max_flows + 1) - for com in combinations(route_opt, elems) -] * 100 +turn_left_routes = [ + ("E0", 0, "E1", 1), + ("E4", 1, "-E0", 0), +] + +turn_right_routes = [ + ("E0", 0, "-E4", 0), + ("-E1", 0, "-E0", 0), +] +# Total route combinations = 8C1 + 8C2 + 8C3 + 8C4 + 8C5 = 218 +# Repeated route combinations = 218 * 2 = 436 +all_routes = horizontal_routes + turn_left_routes + turn_right_routes +route_comb = [ + com for elems in range(1, 6) for com in combinations(all_routes, elems) +] * 2 traffic = {} for name, routes in enumerate(route_comb): traffic[str(name)] = Traffic( flows=[ Flow( route=Route( - begin=("gneE3", start_lane, 0), - end=("gneE4", end_lane, "max"), + begin=(start_edge, start_lane, 0), + end=(end_edge, end_lane, "max"), ), # Random flow rate, between x and y vehicles per minute. - rate=60 * random.uniform(10, 20), + rate=60 * random.uniform(5, 10), # Random flow start time, between x and y seconds. - begin=random.uniform(0, 5), + begin=random.uniform(0, 3), # For an episode with maximum_episode_steps=3000 and step # time=0.1s, the maximum episode time=300s. Hence, traffic is # set to end at 900s, which is greater than maximum episode # time of 300s. end=60 * 15, actors={normal: 1}, - randomly_spaced=True, ) - for start_lane, end_lane in routes + for start_edge, start_lane, end_edge, end_lane in routes ] ) - +route = Route(begin=("E0", 0, 5), end=("E1", 0, "max")) ego_missions = [ Mission( - Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), - entry_tactic=TrapEntryTactic( - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0.1), - wait_to_hijack_limit_s=1, - ), - ), - Mission( - Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), + route=route, entry_tactic=TrapEntryTactic( - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0.1), + start_time=4, wait_to_hijack_limit_s=1, ), ), diff --git a/scenarios/sumo/merge/3lane_agents_1/scenario.py b/scenarios/sumo/merge/3lane_agents_1/scenario.py index 264247876e..1ad7795389 100644 --- a/scenarios/sumo/merge/3lane_agents_1/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_1/scenario.py @@ -83,11 +83,7 @@ ego_missions = [ Mission( Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), - ), - Mission( - Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), ), ] diff --git a/scenarios/sumo/merge/3lane_agents_2/scenario.py b/scenarios/sumo/merge/3lane_agents_2/scenario.py index 264247876e..2c92b1af96 100644 --- a/scenarios/sumo/merge/3lane_agents_2/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_2/scenario.py @@ -83,11 +83,11 @@ ego_missions = [ Mission( Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), ), Mission( Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=1), + entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), ), ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py index 396cbd230e..1728332e76 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py @@ -85,7 +85,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=17, wait_to_hijack_limit_s=0.1 + start_time=17, wait_to_hijack_limit_s=0 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py index 8a8b5ad2e5..35e2fdbe48 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py @@ -82,15 +82,15 @@ ego_missions = [ Mission( Route(begin=("gneE3", 0, 10), end=("gneE3", 0, "max")), - entry_tactic=TrapEntryTactic(start_time=19, wait_to_hijack_limit_s=0.1), + entry_tactic=TrapEntryTactic(start_time=19, wait_to_hijack_limit_s=0), ), Mission( Route(begin=("gneE3", 1, 10), end=("gneE3", 1, "max")), - entry_tactic=TrapEntryTactic(start_time=21, wait_to_hijack_limit_s=0.1), + entry_tactic=TrapEntryTactic(start_time=21, wait_to_hijack_limit_s=0), ), Mission( Route(begin=("gneE3", 2, 10), end=("gneE3", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=17, wait_to_hijack_limit_s=0.1), + entry_tactic=TrapEntryTactic(start_time=17, wait_to_hijack_limit_s=0), ), ] From f2436b284475e636c784983771df5fdb623008d7 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 12:59:47 +0000 Subject: [PATCH 47/59] Remove erroneous imports. --- smarts/core/bubble_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/smarts/core/bubble_manager.py b/smarts/core/bubble_manager.py index cf21e30976..455fcaba20 100644 --- a/smarts/core/bubble_manager.py +++ b/smarts/core/bubble_manager.py @@ -19,14 +19,13 @@ # THE SOFTWARE. import logging import math -from builtins import classmethod from collections import defaultdict from copy import deepcopy from dataclasses import dataclass from enum import Enum from functools import lru_cache from sys import maxsize -from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple from shapely.affinity import rotate, translate from shapely.geometry import CAP_STYLE, JOIN_STYLE, Point, Polygon From 4e47c670a6e7fc14b46c3bb111ce4ce57b347782 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 16:27:19 +0000 Subject: [PATCH 48/59] Short circuit condition operators. --- smarts/sstudio/types.py | 111 +++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 23 deletions(-) diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 0ebe0e7fcb..222e08b84b 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -667,19 +667,67 @@ def requires(self) -> ConditionRequires: raise NotImplementedError() def negation(self) -> "NegatedCondition": - """Negates this condition.""" + """Negates this condition giving the opposite result on evaluation. + + >>> condition_true = LiteralCondition(ConditionState.TRUE) + >>> condition_true.evaluate() + + >>> condition_false = condition_true.negation() + >>> condition_false.evaluate() + + + Note\\: This erases temporal values EXPIRED and BEFORE. + >>> condition_before = LiteralCondition(ConditionState.BEFORE) + >>> condition_before.negation().negation().evaluate() + + + Returns: + NegatedCondition: The wrapped condition. + """ return NegatedCondition(self) def conjunction(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A AND B.""" + """Resolve conditions as A AND B. + + The bit AND operator has been overloaded to call this method. + >>> dependee_condition = DependeeActorCondition("leader") + >>> dependee_condition.evaluate(actor_ids={"leader"}) + + >>> conjunction = dependee_condition & LiteralCondition(ConditionState.FALSE) + >>> conjunction.evaluate(actor_ids={"leader"}) + + + Note that the resolution has the priority EXPIRED > BEFORE > FALSE > TRUE. + >>> conjunction = LiteralCondition(ConditionState.TRUE) & LiteralCondition(ConditionState.BEFORE) + >>> conjunction.evaluate() + + >>> (conjunction & LiteralCondition(ConditionState.EXPIRED)).evaluate() + + + Returns: + CompoundCondition: A condition combining two conditions using an AND operation. + """ return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) def disjunction(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A OR B.""" + """Resolve conditions as A OR B. + + The bit OR operator has been overloaded to call this method. + >>> disjunction = LiteralCondition(ConditionState.TRUE) | LiteralCondition(ConditionState.BEFORE) + >>> disjunction.evaluate() + + + Note that the resolution has the priority TRUE > BEFORE > FALSE > EXPIRED. + >>> disjunction = LiteralCondition(ConditionState.FALSE) | LiteralCondition(ConditionState.EXPIRED) + >>> disjunction.evaluate() + + >>> (disjunction | LiteralCondition(ConditionState.BEFORE)).evaluate() + + """ return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) def implication(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A AND B OR NOT A.""" + """Resolve conditions as A IMPLIES B. This is the same as A AND B OR NOT A.""" return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) def trigger( @@ -740,7 +788,7 @@ def expire(self, time, expired_state=ConditionState.EXPIRED) -> "ExpireTrigger": ) def __and__(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A AND B""" + """Resolve conditions as A AND B.""" assert isinstance(other, Condition) return self.conjunction(other) @@ -836,15 +884,19 @@ def __post_init__(self): @dataclass(frozen=True) class NegatedCondition(Condition): - """This condition negates the inner condition.""" + """This condition negates the inner condition to flip between TRUE and FALSE. + + Note\\: This erases temporal values EXPIRED and BEFORE. + """ inner_condition: Condition """The inner condition to negate.""" def evaluate(self, **kwargs) -> ConditionState: - result = ~self.inner_condition.evaluate(**kwargs) - assert isinstance(result, ConditionState) - return result + result = self.inner_condition.evaluate(**kwargs) + if ConditionState.TRUE in result: + return ConditionState.FALSE + return ConditionState.TRUE @property def requires(self) -> ConditionRequires: @@ -1046,34 +1098,49 @@ class CompoundCondition(Condition): """The operator used to combine these conditions.""" def evaluate(self, **kwargs) -> ConditionState: + # Short circuits first_eval = self.first_condition.evaluate(**kwargs) - if self.operator == ConditionOperator.IMPLICATION and not first_eval: + if ( + self.operator == ConditionOperator.CONJUNCTION + and ConditionState.EXPIRED in first_eval + ): + return ConditionState.EXPIRED + elif ( + self.operator == ConditionOperator.DISJUNCTION + and ConditionState.TRUE in first_eval + ): + return ConditionState.TRUE + elif ( + self.operator == ConditionOperator.IMPLICATION + and ConditionState.TRUE not in first_eval + ): return ConditionState.TRUE second_eval = self.second_condition.evaluate(**kwargs) if ( self.operator == ConditionOperator.IMPLICATION - and first_eval - and second_eval + and ConditionState.TRUE in first_eval + and ConditionState.TRUE in second_eval ): return ConditionState.TRUE - if self.operator == ConditionOperator.CONJUNCTION: - result = first_eval & second_eval - if result: + elif self.operator == ConditionOperator.CONJUNCTION: + conjuction = first_eval & second_eval + if ConditionState.TRUE in conjuction: return ConditionState.TRUE - temporals = (first_eval | second_eval) & ( - ConditionState.BEFORE | ConditionState.EXPIRED - ) - if ConditionState.EXPIRED in temporals: + # To priority of temporal versions of FALSE + disjunction = first_eval | second_eval + if ConditionState.EXPIRED in disjunction: return ConditionState.EXPIRED - return temporals + if ConditionState.BEFORE in disjunction: + return ConditionState.BEFORE + elif self.operator == ConditionOperator.DISJUNCTION: result = first_eval | second_eval - if result: + if ConditionState.TRUE in result: return ConditionState.TRUE if ConditionState.BEFORE in result: @@ -1082,8 +1149,6 @@ def evaluate(self, **kwargs) -> ConditionState: if ConditionState.EXPIRED in first_eval & second_eval: return ConditionState.EXPIRED - return ConditionState.FALSE - return ConditionState.FALSE @cached_property From 32f3999806eef9b91fcc37dde04b9e9b4e3dbc15 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 19:38:30 +0000 Subject: [PATCH 49/59] Fix unbound variable access in trap manager. --- smarts/core/trap_manager.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index 823995df60..10278662ab 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -22,7 +22,7 @@ import random as rand from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from shapely.geometry import Polygon @@ -286,7 +286,15 @@ def step(self, sim): continue vehicle: Optional[Vehicle] = None - if ConditionState.EXPIRED in capture.ready_state: + if capture.ready_state and capture.vehicle_id is not None: + vehicle = self._take_existing_vehicle( + sim, + capture.vehicle_id, + agent_id, + capture.updated_mission, + social=agent_id in sim.agent_manager.pending_social_agent_ids, + ) + elif ConditionState.EXPIRED in capture.ready_state: # Make sure there is not a vehicle in the same location mission = capture.updated_mission if mission is None: @@ -312,14 +320,6 @@ def step(self, sim): trap.default_entry_speed, social=agent_id in sim.agent_manager.pending_social_agent_ids, ) - elif trap_condition and capture.vehicle_id is not None: - vehicle = self._take_existing_vehicle( - sim, - capture.vehicle_id, - agent_id, - capture.updated_mission, - social=agent_id in sim.agent_manager.pending_social_agent_ids, - ) else: continue if vehicle is None: From 732eeb16f6bfda1c295656627b02307cb6cd6e5c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 19:39:49 +0000 Subject: [PATCH 50/59] Address reviews. --- smarts/core/id_actor_capture_manager.py | 2 -- smarts/sstudio/types.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/smarts/core/id_actor_capture_manager.py b/smarts/core/id_actor_capture_manager.py index bf20e82471..b393f91eb8 100644 --- a/smarts/core/id_actor_capture_manager.py +++ b/smarts/core/id_actor_capture_manager.py @@ -71,7 +71,6 @@ def step(self, sim): ) condition_result = entry_tactic.condition.evaluate(**condition_kwargs) if condition_result == ConditionState.EXPIRED: - print(condition_result) self._log.warning( f"Actor aquisition skipped for `{agent_id}` scheduled to start between " + f"`Condition `{entry_tactic.condition}` has expired with no vehicle." @@ -81,7 +80,6 @@ def step(self, sim): continue if not condition_result: continue - print(condition_result) vehicle: Optional[Vehicle] = self._take_existing_vehicle( sim, actor_id, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 222e08b84b..e403e7ade4 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -1168,8 +1168,6 @@ class EntryTactic: """The tactic that the simulation should use to acquire a vehicle for an agent.""" start_time: float - # condition: Condition - # """The condition to determine if this entry tactic should be used.""" @dataclass(frozen=True) From 6ebfe76cad52d1d31b04f0e35dee0591d41f1687 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 19:50:38 +0000 Subject: [PATCH 51/59] Fix unintentional scenario changes. --- scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py | 2 +- scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py | 2 +- scenarios/sumo/zoo_intersection/scenario.py | 2 +- smarts/core/trap_manager.py | 1 - smarts/env/tests/test_social_agent.py | 2 -- 5 files changed, 3 insertions(+), 6 deletions(-) diff --git a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py index d9d9d8cbec..1951ca6393 100644 --- a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py @@ -96,7 +96,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=20, wait_to_hijack_limit_s=0.1 + start_time=20, wait_to_hijack_limit_s=0 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py index b20b9a4123..3c96a19f23 100644 --- a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py @@ -85,7 +85,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=17, wait_to_hijack_limit_s=0.1 + start_time=17, wait_to_hijack_limit_s=0 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/zoo_intersection/scenario.py b/scenarios/sumo/zoo_intersection/scenario.py index 0dcbcaff2c..5aea7fae8f 100644 --- a/scenarios/sumo/zoo_intersection/scenario.py +++ b/scenarios/sumo/zoo_intersection/scenario.py @@ -100,7 +100,7 @@ EndlessMission( begin=("edge-south-SN", 0, 10), entry_tactic=TrapEntryTactic( - start_time=0.7, wait_to_hijack_limit_s=0.1 + start_time=0.7, wait_to_hijack_limit_s=0 ), ), ], diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index 10278662ab..dd83812c0f 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -19,7 +19,6 @@ # THE SOFTWARE. import logging import math -import random as rand from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple diff --git a/smarts/env/tests/test_social_agent.py b/smarts/env/tests/test_social_agent.py index 0f0f55ffd8..62edcbbeba 100644 --- a/smarts/env/tests/test_social_agent.py +++ b/smarts/env/tests/test_social_agent.py @@ -22,9 +22,7 @@ import gym import pytest -from smarts.core.agent import Agent from smarts.core.agent_interface import AgentInterface, AgentType -from smarts.core.utils.episodes import episodes from smarts.env.hiway_env import HiWayEnv AGENT_ID = "Agent-007" From 701920531072b938fce66aa787877e03a834801c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 20:03:36 +0000 Subject: [PATCH 52/59] Default wait_to_hijack_limit_s to 0. --- .../1_to_2lane_left_turn_c_agents_1/scenario.py | 2 +- scenarios/sumo/merge/3lane_agents_1/scenario.py | 2 +- scenarios/sumo/merge/3lane_agents_2/scenario.py | 4 ++-- .../sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py | 4 +--- scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py | 2 +- scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py | 6 +++--- scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py | 2 +- scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py | 2 +- scenarios/sumo/zoo_intersection/scenario.py | 4 +--- smarts/core/tests/test_trap_manager.py | 1 - smarts/core/utils/tests/fixtures.py | 1 - smarts/sstudio/types.py | 2 +- 12 files changed, 13 insertions(+), 19 deletions(-) diff --git a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py index 688f19bc43..46955f3b18 100644 --- a/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py +++ b/scenarios/sumo/intersections/1_to_2lane_left_turn_c_agents_1/scenario.py @@ -95,7 +95,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=4, wait_to_hijack_limit_s=0 + start_time=4 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/merge/3lane_agents_1/scenario.py b/scenarios/sumo/merge/3lane_agents_1/scenario.py index 1ad7795389..5c42c18e58 100644 --- a/scenarios/sumo/merge/3lane_agents_1/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_1/scenario.py @@ -83,7 +83,7 @@ ego_missions = [ Mission( Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=15), ), ] diff --git a/scenarios/sumo/merge/3lane_agents_2/scenario.py b/scenarios/sumo/merge/3lane_agents_2/scenario.py index 2c92b1af96..7a62dc7d2d 100644 --- a/scenarios/sumo/merge/3lane_agents_2/scenario.py +++ b/scenarios/sumo/merge/3lane_agents_2/scenario.py @@ -83,11 +83,11 @@ ego_missions = [ Mission( Route(begin=("gneE6", 0, 10), end=("gneE4", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=15), ), Mission( Route(begin=("gneE3", 0, 10), end=("gneE4", 0, "max")), - entry_tactic=TrapEntryTactic(start_time=15, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=15), ), ] diff --git a/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py b/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py index a6637814a0..1faae67c7c 100644 --- a/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py +++ b/scenarios/sumo/platoon/merge_exit_sumo_t_agents_1/scenario.py @@ -99,9 +99,7 @@ ego_missions = [ EndlessMission( begin=("E0", 2, 5), - entry_tactic=TrapEntryTactic( - start_time=31, wait_to_hijack_limit_s=0, default_entry_speed=0 - ), + entry_tactic=TrapEntryTactic(start_time=31, default_entry_speed=0), ) # Delayed start, to ensure road has prior traffic. ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py index 1728332e76..3976eb6db9 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_1/scenario.py @@ -85,7 +85,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=17, wait_to_hijack_limit_s=0 + start_time=17 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py index 35e2fdbe48..d1503ec72d 100644 --- a/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py +++ b/scenarios/sumo/straight/3lane_cruise_agents_3/scenario.py @@ -82,15 +82,15 @@ ego_missions = [ Mission( Route(begin=("gneE3", 0, 10), end=("gneE3", 0, "max")), - entry_tactic=TrapEntryTactic(start_time=19, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=19), ), Mission( Route(begin=("gneE3", 1, 10), end=("gneE3", 1, "max")), - entry_tactic=TrapEntryTactic(start_time=21, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=21), ), Mission( Route(begin=("gneE3", 2, 10), end=("gneE3", 2, "max")), - entry_tactic=TrapEntryTactic(start_time=17, wait_to_hijack_limit_s=0), + entry_tactic=TrapEntryTactic(start_time=17), ), ] diff --git a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py index 1951ca6393..b0d5884534 100644 --- a/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_cut_in_agents_1/scenario.py @@ -96,7 +96,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=20, wait_to_hijack_limit_s=0 + start_time=20 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py index 3c96a19f23..2372e2e33a 100644 --- a/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py +++ b/scenarios/sumo/straight/3lane_overtake_agents_1/scenario.py @@ -85,7 +85,7 @@ Mission( route=route, entry_tactic=TrapEntryTactic( - start_time=17, wait_to_hijack_limit_s=0 + start_time=17 ), # Delayed start, to ensure road has prior traffic. ) ] diff --git a/scenarios/sumo/zoo_intersection/scenario.py b/scenarios/sumo/zoo_intersection/scenario.py index 5aea7fae8f..d21f018630 100644 --- a/scenarios/sumo/zoo_intersection/scenario.py +++ b/scenarios/sumo/zoo_intersection/scenario.py @@ -99,9 +99,7 @@ [ EndlessMission( begin=("edge-south-SN", 0, 10), - entry_tactic=TrapEntryTactic( - start_time=0.7, wait_to_hijack_limit_s=0 - ), + entry_tactic=TrapEntryTactic(start_time=0.7), ), ], ), diff --git a/smarts/core/tests/test_trap_manager.py b/smarts/core/tests/test_trap_manager.py index 0e692b432c..d835a61100 100644 --- a/smarts/core/tests/test_trap_manager.py +++ b/smarts/core/tests/test_trap_manager.py @@ -107,7 +107,6 @@ def two_agent_capture_offset_tenth_of_second(): t.Mission( t.Route(begin=("west", 1, 20), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( - wait_to_hijack_limit_s=0, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), ), diff --git a/smarts/core/utils/tests/fixtures.py b/smarts/core/utils/tests/fixtures.py index 84cd1ab17f..43881de4bb 100644 --- a/smarts/core/utils/tests/fixtures.py +++ b/smarts/core/utils/tests/fixtures.py @@ -83,7 +83,6 @@ def large_observation(): route_vias=(), start_time=0.1, entry_tactic=t.TrapEntryTactic( - wait_to_hijack_limit_s=0, zone=None, exclusion_prefixes=(), default_entry_speed=None, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index e403e7ade4..858d12d6af 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -1174,7 +1174,7 @@ class EntryTactic: class TrapEntryTactic(EntryTactic): """An entry tactic that repurposes a pre-existing vehicle for an agent.""" - wait_to_hijack_limit_s: float + wait_to_hijack_limit_s: float = 0 """The amount of seconds a hijack will wait to get a vehicle before defaulting to a new vehicle""" zone: Optional["MapZone"] = None """The zone of the hijack area""" From b657797ddf27bf7e86d70ef4d4a0e2d13f9912b6 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 20:46:21 +0000 Subject: [PATCH 53/59] Allow for relative expiry on ExpireTrigger. --- smarts/sstudio/tests/test_conditions.py | 10 +++++- smarts/sstudio/types.py | 47 ++++++++++++++++++------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 08b276af08..902728189f 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -252,7 +252,7 @@ def test_condition_trigger(): assert not delayed_condition.evaluate(time=time) -def test_expiring_trigger(): +def test_expire_trigger(): end_time = 10 before = end_time - 1 after = end_time + 1 @@ -264,6 +264,14 @@ def test_expiring_trigger(): assert not expire_trigger.evaluate(time=end_time) assert not expire_trigger.evaluate(time=after) + first_time = 3 + expire_trigger = literal_true.expire(end_time, relative=True) + assert expire_trigger.evaluate(time=first_time) + assert expire_trigger.evaluate(time=first_time + before) + assert expire_trigger.evaluate(time=end_time) + assert not expire_trigger.evaluate(time=first_time + end_time) + assert not expire_trigger.evaluate(time=first_time + after) + def test_dependee_condition(): dependee_condition = DependeeActorCondition("leader") diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 858d12d6af..55925c2b1f 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -767,7 +767,9 @@ def trigger( self, delay_seconds=delay_seconds, persistant=persistant ) - def expire(self, time, expired_state=ConditionState.EXPIRED) -> "ExpireTrigger": + def expire( + self, time, expired_state=ConditionState.EXPIRED, relative: bool = False + ) -> "ExpireTrigger": """This trigger evaluates to the expired state value after the given simulation time. >>> trigger = LiteralCondition(ConditionState.TRUE).expire(20) @@ -779,12 +781,15 @@ def expire(self, time, expired_state=ConditionState.EXPIRED) -> "ExpireTrigger": Args: time (float): The simulation time when this trigger changes. expired_state (ConditionState, optional): The condition state to use when the simulation is after the given time. Defaults to ConditionState.EXPIRED. - + relative (bool, optional): If this trigger should resolve relative to the first evaluated time. Returns: ExpireTrigger: The resulting condition. """ return ExpireTrigger( - inner_condition=self, time=time, expired_state=expired_state + inner_condition=self, + time=time, + expired_state=expired_state, + relative=relative, ) def __and__(self, other: "Condition") -> "CompoundCondition": @@ -922,8 +927,19 @@ class ExpireTrigger(Condition): expired_state: ConditionState = ConditionState.EXPIRED """The state value this trigger should have when it expires.""" + relative: bool = False + """If this should start relative to the first time evaluated.""" + def evaluate(self, **kwargs) -> ConditionState: time = kwargs[ConditionRequires.time.name] + if self.relative: + key = "met" + met_time = getattr(self, key, -1) + if met_time == -1: + object.__setattr__(self, key, time) + time = 0 + else: + time -= met_time if time >= self.time: return self.expired_state return self.inner_condition.evaluate(**kwargs) @@ -971,16 +987,19 @@ def evaluate(self, **kwargs) -> ConditionState: key = "met_time" result = self.untriggered_state met_time = getattr(self, key, -1) - if met_time == -1 and self.inner_condition.evaluate(**kwargs): - object.__setattr__(self, key, time) - if met_time != -1 or ( - self.delay_seconds == 0 and self.inner_condition.evaluate(**kwargs) - ): - if time >= met_time + self.delay_seconds: - result = self.triggered_state - if self.persistant: - result &= self.inner_condition.evaluate(**kwargs) - return result + if met_time == -1: + if self.inner_condition.evaluate(**kwargs): + object.__setattr__(self, key, time) + time = 0 + else: + time = -1 + else: + time -= met_time + if time >= self.delay_seconds: + result = self.triggered_state + if self.persistant: + result &= self.inner_condition.evaluate(**kwargs) + return result temporals = result & (ConditionState.EXPIRED) if ConditionState.EXPIRED in temporals: @@ -996,6 +1015,8 @@ def __post_init__(self): raise TypeError( f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." ) + if self.delay_seconds < 0: + raise ValueError("Delay cannot be negative.") @dataclass(frozen=True) From f415d7781a9cb63026ae95020e49eea63a3ba95e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 21:00:51 +0000 Subject: [PATCH 54/59] Wrap up PR. --- CHANGELOG.md | 4 ++++ smarts/core/tests/test_trap_manager.py | 4 ++++ smarts/core/utils/tests/fixtures.py | 1 + smarts/sstudio/types.py | 5 +++++ 4 files changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53c60ef247..6e5ee98ff9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,8 +22,12 @@ Copy and pasting the git commit messages is __NOT__ enough. - Reward function in platoon RL example retrieves actor-of-interest from marked neighborhood vehicles. - `dist_to_destination` metric cost function computes the route distance and end point for vehicle of interest contained in social agents, SMARTS traffic provider, SUMO traffic provider, and traffic history provider. - `sstudio` generated scenario vehicle traffic IDs are now shortened. +- `TrapEntryTactic.wait_to_hijack_limit_s` field now defaults to `0`. +- `EntryTactic` derived classes now contain `condition` to provide extra filtering of candidate actors. +- `EntryTactic` derived classes now contain `start_time`. ### Deprecated - `visdom` is set to be removed from the SMARTS object parameters. +- Deprecated `start_time` on missions. ### Fixed - Fixed implementations of `RoadMap.waypoint_paths()` to ensure that the result is never empty. - The routes of `SumoTrafficSimulation` traffic vehicles are now preserved to be passed over to other traffic simulators when the `SumoTrafficSimulation` disconnects. diff --git a/smarts/core/tests/test_trap_manager.py b/smarts/core/tests/test_trap_manager.py index d835a61100..db96b3a790 100644 --- a/smarts/core/tests/test_trap_manager.py +++ b/smarts/core/tests/test_trap_manager.py @@ -61,6 +61,7 @@ def scenarios(traffic_sim): t.Mission( t.Route(begin=("west", 1, 10), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( + start_time=0, wait_to_hijack_limit_s=3, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), @@ -84,6 +85,7 @@ def empty_scenarios(): t.Mission( t.Route(begin=("west", 1, 10), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( + start_time=0, wait_to_hijack_limit_s=3, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), @@ -107,12 +109,14 @@ def two_agent_capture_offset_tenth_of_second(): t.Mission( t.Route(begin=("west", 1, 20), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( + start_time=0, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), ), t.Mission( t.Route(begin=("west", 2, 10), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( + start_time=0, wait_to_hijack_limit_s=0.1, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), diff --git a/smarts/core/utils/tests/fixtures.py b/smarts/core/utils/tests/fixtures.py index 43881de4bb..a99abeae02 100644 --- a/smarts/core/utils/tests/fixtures.py +++ b/smarts/core/utils/tests/fixtures.py @@ -83,6 +83,7 @@ def large_observation(): route_vias=(), start_time=0.1, entry_tactic=t.TrapEntryTactic( + start_time=0, zone=None, exclusion_prefixes=(), default_entry_speed=None, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py index 55925c2b1f..3d846b86dd 100644 --- a/smarts/sstudio/types.py +++ b/smarts/sstudio/types.py @@ -1190,6 +1190,11 @@ class EntryTactic: start_time: float + def __post_init__(self): + assert ( + getattr(self, "condition", None) is not None + ), "Abstract class, inheriting types must implement the `condition` field." + @dataclass(frozen=True) class TrapEntryTactic(EntryTactic): From 5ce40813aa18861db621ed1fe1b21df80c693b8a Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 12 May 2023 21:30:52 +0000 Subject: [PATCH 55/59] Fix trap manager test. --- smarts/core/tests/test_trap_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smarts/core/tests/test_trap_manager.py b/smarts/core/tests/test_trap_manager.py index db96b3a790..855f7cfad7 100644 --- a/smarts/core/tests/test_trap_manager.py +++ b/smarts/core/tests/test_trap_manager.py @@ -61,7 +61,7 @@ def scenarios(traffic_sim): t.Mission( t.Route(begin=("west", 1, 10), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( - start_time=0, + start_time=0.1, wait_to_hijack_limit_s=3, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), @@ -85,7 +85,7 @@ def empty_scenarios(): t.Mission( t.Route(begin=("west", 1, 10), end=("east", 1, "max")), entry_tactic=t.TrapEntryTactic( - start_time=0, + start_time=0.1, wait_to_hijack_limit_s=3, zone=t.MapZone(start=("west", 0, 1), length=100, n_lanes=3), ), From ee5c0c140d22a86c096b32546464a08751ca2b5c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Mon, 15 May 2023 16:50:53 +0000 Subject: [PATCH 56/59] Refactor sstudio types. --- smarts/sstudio/genscenario.py | 16 +- smarts/sstudio/tests/baseline.rou.xml | 21 +- smarts/sstudio/tests/test_conditions.py | 1 - smarts/sstudio/types.py | 1783 ----------------- smarts/sstudio/types/__init__.py | 35 + smarts/sstudio/types/actor/__init__.py | 31 + .../sstudio/types/actor/social_agent_actor.py | 67 + smarts/sstudio/types/actor/traffic_actor.py | 75 + smarts/sstudio/types/bubble.py | 119 ++ smarts/sstudio/types/bubble_limits.py | 42 + smarts/sstudio/types/condition.py | 610 ++++++ smarts/sstudio/types/constants.py | 26 + smarts/sstudio/types/dataset.py | 71 + smarts/sstudio/types/distribution.py | 78 + smarts/sstudio/types/entry_tactic.py | 82 + smarts/sstudio/types/map_spec.py | 59 + smarts/sstudio/types/mission.py | 145 ++ smarts/sstudio/types/route.py | 128 ++ smarts/sstudio/types/scenario.py | 92 + smarts/sstudio/types/traffic.py | 129 ++ smarts/sstudio/types/traffic_model.py | 169 ++ smarts/sstudio/types/zone.py | 274 +++ 22 files changed, 2257 insertions(+), 1796 deletions(-) delete mode 100644 smarts/sstudio/types.py create mode 100644 smarts/sstudio/types/__init__.py create mode 100644 smarts/sstudio/types/actor/__init__.py create mode 100644 smarts/sstudio/types/actor/social_agent_actor.py create mode 100644 smarts/sstudio/types/actor/traffic_actor.py create mode 100644 smarts/sstudio/types/bubble.py create mode 100644 smarts/sstudio/types/bubble_limits.py create mode 100644 smarts/sstudio/types/condition.py create mode 100644 smarts/sstudio/types/constants.py create mode 100644 smarts/sstudio/types/dataset.py create mode 100644 smarts/sstudio/types/distribution.py create mode 100644 smarts/sstudio/types/entry_tactic.py create mode 100644 smarts/sstudio/types/map_spec.py create mode 100644 smarts/sstudio/types/mission.py create mode 100644 smarts/sstudio/types/route.py create mode 100644 smarts/sstudio/types/scenario.py create mode 100644 smarts/sstudio/types/traffic.py create mode 100644 smarts/sstudio/types/traffic_model.py create mode 100644 smarts/sstudio/types/zone.py diff --git a/smarts/sstudio/genscenario.py b/smarts/sstudio/genscenario.py index b1840c6bad..9bc97d1789 100644 --- a/smarts/sstudio/genscenario.py +++ b/smarts/sstudio/genscenario.py @@ -28,7 +28,7 @@ import os import pickle import sqlite3 -from dataclasses import asdict, replace +from dataclasses import asdict, dataclass, replace from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -46,6 +46,18 @@ logger.setLevel(logging.WARNING) +@dataclass(frozen=True) +class ActorAndMission: + """Holds an Actor object and its associated Mission.""" + + actor: types.Actor + """Specification for traffic actor. + """ + mission: Union[types.Mission, types.EndlessMission, types.LapMission] + """Mission for traffic actor. + """ + + def _check_if_called_externally(): frame_info = inspect.stack()[2] module = inspect.getmodule(frame_info[0]) @@ -592,7 +604,7 @@ def resolve_mission(mission): _validate_missions(missions) missions = [ - types.ActorAndMission(actor=actor, mission=resolve_mission(mission)) + ActorAndMission(actor=actor, mission=resolve_mission(mission)) for actor, mission in itertools.product(actors, missions) ] with open(output_path, "wb") as f: diff --git a/smarts/sstudio/tests/baseline.rou.xml b/smarts/sstudio/tests/baseline.rou.xml index e571acafa3..5baa590cc5 100644 --- a/smarts/sstudio/tests/baseline.rou.xml +++ b/smarts/sstudio/tests/baseline.rou.xml @@ -1,15 +1,15 @@ - - - + + - - + + - + - + diff --git a/smarts/sstudio/tests/test_conditions.py b/smarts/sstudio/tests/test_conditions.py index 902728189f..f267122f82 100644 --- a/smarts/sstudio/tests/test_conditions.py +++ b/smarts/sstudio/tests/test_conditions.py @@ -27,7 +27,6 @@ CompoundCondition, Condition, ConditionOperator, - ConditionRequires, ConditionState, ConditionTrigger, DependeeActorCondition, diff --git a/smarts/sstudio/types.py b/smarts/sstudio/types.py deleted file mode 100644 index 3d846b86dd..0000000000 --- a/smarts/sstudio/types.py +++ /dev/null @@ -1,1783 +0,0 @@ -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -import collections.abc as collections_abc -import enum -import logging -import math -import random -import sys -import warnings -from dataclasses import dataclass, field, replace -from enum import IntEnum, IntFlag -from functools import cached_property -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) - -import numpy as np -from shapely.affinity import rotate as shapely_rotate -from shapely.affinity import translate as shapely_translate -from shapely.geometry import ( - GeometryCollection, - LineString, - MultiPolygon, - Point, - Polygon, - box, -) -from shapely.ops import split, unary_union - -from smarts.core import gen_id -from smarts.core.colors import Colors -from smarts.core.condition_state import ConditionState -from smarts.core.coordinates import RefLinePoint -from smarts.core.default_map_builder import get_road_map -from smarts.core.road_map import RoadMap -from smarts.core.utils.file import pickle_hash_int -from smarts.core.utils.id import SocialAgentId -from smarts.core.utils.math import rotate_cw_around_point - -MISSING = sys.maxsize - - -class _SUMO_PARAMS_MODE(IntEnum): - TITLE_CASE = 0 - KEEP_SNAKE_CASE = 1 - - -class _SumoParams(collections_abc.Mapping): - """For some Sumo params (e.g. LaneChangingModel) the arguments are in title case - with a given prefix. Subclassing this class allows for an automatic way to map - between PEP8-compatible naming and Sumo's. - """ - - def __init__( - self, prefix, whitelist=[], mode=_SUMO_PARAMS_MODE.TITLE_CASE, **kwargs - ): - def snake_to_title(word): - return "".join(x.capitalize() or "_" for x in word.split("_")) - - def keep_snake_case(word: str): - w = word[0].upper() + word[1:] - return "".join(x or "_" for x in w.split("_")) - - func: Callable[[str], str] = snake_to_title - if mode == _SUMO_PARAMS_MODE.TITLE_CASE: - pass - elif mode == _SUMO_PARAMS_MODE.KEEP_SNAKE_CASE: - func = keep_snake_case - - # XXX: On rare occasions sumo doesn't respect their own conventions - # (e.x. junction model's impatience). - self._params = {key: kwargs.pop(key) for key in whitelist if key in kwargs} - - for key, value in kwargs.items(): - self._params[f"{prefix}{func(key)}"] = value - - def __iter__(self): - return iter(self._params) - - def __getitem__(self, key): - return self._params[key] - - def __len__(self): - return len(self._params) - - def __hash__(self): - return hash(frozenset(self._params.items())) - - def __eq__(self, other): - return self.__class__ == other.__class__ and hash(self) == hash(other) - - -class LaneChangingModel(_SumoParams): - """Models how the actor acts with respect to lane changes.""" - - # For SUMO-specific attributes, see: - # https://sumo.dlr.de/docs/Definition_of_Vehicles%2C_Vehicle_Types%2C_and_Routes.html#lane-changing_models - - def __init__(self, **kwargs): - super().__init__("lc", whitelist=["minGapLat"], **kwargs) - - -class JunctionModel(_SumoParams): - """Models how the actor acts with respect to waiting at junctions.""" - - def __init__(self, **kwargs): - super().__init__("jm", whitelist=["impatience"], **kwargs) - - -class SmartsLaneChangingModel(LaneChangingModel): - """Implements the simple lane-changing model built-into SMARTS. - - Args: - cutin_prob (float, optional): Float value [0, 1] that - determines the probabilty this vehicle will "arbitrarily" cut in - front of an adjacent agent vehicle when it has a chance, even if - there would otherwise be no reason to change lanes at that point. - Higher values risk a situation where this vehicle ends up in a lane - where it cannot maintain its planned route. If that happens, this - vehicle will perform whatever its default behavior is when it - completes its route. Defaults to 0.0. - assertive (float, optional): Willingness to accept lower front and rear - gaps in the target lane. The required gap is divided by this value. - Attempts to match the semantics of the attribute in SUMO's default - lane-changing model, see: ``https://sumo.dlr.de/docs/Definition_of_Vehicles%2C_Vehicle_Types%2C_and_Routes.html#lane-changing_models``. - Range: positive reals. Defaults to 1.0. - dogmatic (bool, optional): If True, will cutin when a suitable - opportunity presents itself based on the above parameters, even if - it means the risk of not not completing the assigned route; - otherwise, will forego the chance. Defaults to True. - hold_period (float, optional): The minimum amount of time (seconds) to - remain in the agent's lane after cutting into it (including the - time it takes within the lane to complete the maneuver). Must be - non-negative. Defaults to 3.0. - slow_down_after (float, optional): Target speed during the hold_period - will be scaled by this value. Must be non-negative. Defaults to 1.0. - multi_lane_cutin (bool, optional): If True, this vehicle will consider - changing across multiple lanes at once in order to cutin upon an - agent vehicle when there's an opportunity. Defaults to False. - """ - - def __init__( - self, - cutin_prob: float = 0.0, - assertive: float = 1.0, - dogmatic: bool = True, - hold_period: float = 3.0, - slow_down_after: float = 1.0, - multi_lane_cutin: bool = False, - ): - super().__init__( - cutin_prob=cutin_prob, - assertive=assertive, - dogmatic=dogmatic, - hold_period=hold_period, - slow_down_after=slow_down_after, - multi_lane_cutin=multi_lane_cutin, - ) - - -class SmartsJunctionModel(JunctionModel): - """Implements the simple junction model built-into SMARTS. - - Args: - yield_to_agents (str, optional): Defaults to "normal". 3 options are - available, namely: (1) "always" - Traffic actors will yield to Ego - and Social agents within junctions. (2) "never" - Traffic actors - will never yield to Ego or Social agents within junctions. - (3) "normal" - Traffic actors will attempt to honor normal - right-of-way conventions, only yielding when an agent has the - right-of-way. Examples of such conventions include (a) vehicles - going straight have the right-of-way over turning vehicles; - (b) vehicles on roads with more lanes have the right-of-way - relative to vehicles on intersecting roads with less lanes; - (c) all other things being equal, the vehicle to the right - in a counter-clockwise sense has the right-of-way. - wait_to_restart (float, optional): The amount of time in seconds - after stopping at a signal or stop sign before this vehicle - will start to go again. Defaults to 0.0. - """ - - def __init__(self, yield_to_agents: str = "normal", wait_to_restart: float = 0.0): - super().__init__( - yield_to_agents=yield_to_agents, wait_to_restart=wait_to_restart - ) - - -@dataclass(frozen=True) -class Distribution: - """A gaussian distribution used for randomized parameters.""" - - mean: float - """The mean value of the gaussian distribution.""" - sigma: float - """The sigma value of the gaussian distribution.""" - - def sample(self): - """The next sample from the distribution.""" - return random.gauss(self.mean, self.sigma) - - -@dataclass -class UniformDistribution: - """A uniform distribution, return a random number N - such that a <= N <= b for a <= b and b <= N <= a for b < a. - """ - - a: float - b: float - - def __post_init__(self): - if self.b < self.a: - self.a, self.b = self.b, self.a - - def sample(self): - """Get the next sample.""" - return random.uniform(self.a, self.b) - - -@dataclass -class TruncatedDistribution: - """A truncated normal distribution, by default, location=0, scale=1""" - - a: float - b: float - loc: float = 0 - scale: float = 1 - - def __post_init__(self): - assert self.a != self.b - if self.b < self.a: - self.a, self.b = self.b, self.a - - def sample(self): - """Get the next sample""" - from scipy.stats import truncnorm - - return truncnorm.rvs(self.a, self.b, loc=self.loc, scale=self.scale) - - -@dataclass(frozen=True) -class Actor: - """This is the base description/spec type for traffic actors.""" - - pass - - -@dataclass(frozen=True) -class TrafficActor(Actor): - """Used as a description/spec for traffic actors (e.x. Vehicles, Pedestrians, - etc). The defaults provided are for a car, but the name is not set to make it - explicit that you actually want a car. - """ - - name: str - """The name of the traffic actor. It must be unique.""" - accel: float = 2.6 - """The maximum acceleration value of the actor (in m/s^2).""" - decel: float = 4.5 - """The maximum deceleration value of the actor (in m/s^2).""" - tau: float = 1.0 - """The minimum time headway""" - sigma: float = 0.5 - """The driver imperfection""" # TODO: appears to not be used in generators.py - depart_speed: Union[float, str] = "max" - """The starting speed of the actor""" - emergency_decel: float = 4.5 - """maximum deceleration ability of vehicle in case of emergency""" - speed: Distribution = Distribution(mean=1.0, sigma=0.1) - """The speed distribution of this actor in m/s.""" - imperfection: Distribution = Distribution(mean=0.5, sigma=0) - """Driver imperfection within range [0..1]""" - min_gap: Distribution = Distribution(mean=2.5, sigma=0) - """Minimum gap (when standing) in meters.""" - max_speed: float = 55.5 - """The vehicle's maximum velocity (in m/s), defaults to 200 km/h for vehicles""" - vehicle_type: str = "passenger" - """The configured vehicle type this actor will perform as. ("passenger", "bus", "coach", "truck", "trailer")""" - lane_changing_model: LaneChangingModel = field( - default_factory=LaneChangingModel, hash=False - ) - junction_model: JunctionModel = field(default_factory=JunctionModel, hash=False) - - def __hash__(self) -> int: - return pickle_hash_int(self) - - @property - def id(self) -> str: - """The identifier tag of the traffic actor.""" - return "{}-{}".format(self.name, str(hash(self))[:6]) - - -@dataclass(frozen=True) -class SocialAgentActor(Actor): - """Used as a description/spec for zoo traffic actors. These actors use a - pre-trained model to understand how to act in the environment. - """ - - name: str - """The name of the social actor. Must be unique.""" - - # A pre-registered zoo identifying tag you provide to help SMARTS identify the - # prefab of a social agent. - agent_locator: str - """The locator reference to the zoo registration call. Expects a string in the format - of 'path.to.file:locator-name' where the path to the registration call is in the form - {PYTHONPATH}[n]/path/to/file.py - """ - policy_kwargs: Dict[str, Any] = field(default_factory=dict) - """Additional keyword arguments to be passed to the constructed class overriding the - existing registered arguments. - """ - initial_speed: Optional[float] = None - """Set the initial speed, defaults to 0.""" - - -@dataclass(frozen=True) -class BoidAgentActor(SocialAgentActor): - """Used as a description/spec for boid traffic actors. Boid actors control multiple - vehicles. - """ - - id: str = field(default_factory=lambda: f"boid-{gen_id()}") - - # The max number of vehicles that this agent will control at a time. This value is - # honored when using a bubble for boid dynamic assignment. - capacity: "BubbleLimits" = None - """The capacity of the boid agent to take over vehicles.""" - - -# A MapBuilder should return an object derived from the RoadMap base class -# and a hash that uniquely identifies it (changes to the hash should signify -# that the map is different enough that map-related caches should be reloaded). -# -# This function should be re-callable (although caching is up to the implementation). -# The idea here is that anything in SMARTS that needs to use a RoadMap -# can call this builder to get or create one as necessary. -MapBuilder = Callable[[Any], Tuple[Optional[RoadMap], Optional[str]]] - - -@dataclass(frozen=True) -class MapSpec: - """A map specification that describes how to generate a roadmap.""" - - source: str - """A path or URL or name uniquely designating the map source.""" - lanepoint_spacing: float = 1.0 - """The default distance between pre-generated Lane Points (Waypoints).""" - default_lane_width: Optional[float] = None - """If specified, the default width (in meters) of lanes on this map.""" - shift_to_origin: bool = False - """If True, upon creation a map whose bounding-box does not intersect with - the origin point (0,0) will be shifted such that it does.""" - builder_fn: MapBuilder = get_road_map - """If specified, this should return an object derived from the RoadMap base class - and a hash that uniquely identifies it (changes to the hash should signify - that the map is different enough that map-related caches should be reloaded). - The parameter is this MapSpec object itself. - If not specified, this currently defaults to a function that creates - SUMO road networks (get_road_map()) in smarts.core.default_map_builder.""" - - -@dataclass(frozen=True) -class Route: - """A route is represented by begin and end road IDs, with an optional list of - intermediary road IDs. When an intermediary is not specified the router will - decide what it should be. - """ - - ## road, lane index, offset - begin: Tuple[str, int, Any] - """The (road, lane_index, offset) details of the start location for the route. - - road: - The starting road by name. - lane_index: - The lane index from the rightmost lane. - offset: - The offset in metres into the lane. Also acceptable\\: "max", "random" - """ - ## road, lane index, offset - end: Tuple[str, int, Any] - """The (road, lane_index, offset) details of the end location for the route. - - road: - The starting road by name. - lane_index: - The lane index from the rightmost lane. - offset: - The offset in metres into the lane. Also acceptable\\: "max", "random" - """ - - # Roads we want to make sure this route includes - via: Tuple[str, ...] = field(default_factory=tuple) - """The ids of roads that must be included in the route between `begin` and `end`.""" - - map_spec: Optional[MapSpec] = None - """All routes are relative to a road map. If not specified here, - the default map_spec for the scenario is used.""" - - @property - def id(self) -> str: - """The unique id of this route.""" - return "{}-{}-{}".format( - "_".join(map(str, self.begin)), - "_".join(map(str, self.end)), - str(hash(self))[:6], - ) - - @property - def roads(self): - """All roads that are used within this route.""" - return (self.begin[0],) + self.via + (self.end[0],) - - def __hash__(self): - return pickle_hash_int(self) - - def __eq__(self, other): - return self.__class__ == other.__class__ and hash(self) == hash(other) - - -@dataclass(frozen=True) -class RandomRoute: - """An alternative to types.Route which specifies to sstudio to generate a random - route. - """ - - id: str = field(default_factory=lambda: f"random-route-{gen_id()}") - - map_spec: Optional[MapSpec] = None - """All routes are relative to a road map. If not specified here, - the default map_spec for the scenario is used.""" - - def __hash__(self): - return hash(self.id) - - def __eq__(self, other): - return self.__class__ == other.__class__ and hash(self) == hash(other) - - -@dataclass(frozen=True) -class Flow: - """A route with an actor type emitted at a given rate.""" - - route: Union[RandomRoute, Route] - """The route for the actor to attempt to follow.""" - rate: float - """Vehicles per hour.""" - begin: float = 0 - """Start time in seconds.""" - # XXX: Defaults to 1 hour of traffic. We may want to change this to be "continual - # traffic", effectively an infinite end. - end: float = 1 * 60 * 60 - """End time in seconds.""" - actors: Dict[TrafficActor, float] = field(default_factory=dict) - """An actor to weight mapping associated as\\: { actor: weight } - - :param actor: The traffic actors that are provided. - :param weight: The chance of this actor appearing as a ratio over total weight. - """ - randomly_spaced: bool = False - """Determines if the flow should have randomly spaced traffic. Defaults to `False`.""" - repeat_route: bool = False - """If True, vehicles that finish their route will be restarted at the beginning. Defaults to `False`.""" - - @property - def id(self) -> str: - """The unique id of this flow.""" - return "{}-{}".format( - self.route.id, - str(hash(self))[:6], - ) - - def __hash__(self): - # Custom hash since self.actors is not hashable, here we first convert to a - # frozenset. - return pickle_hash_int((self.route, self.rate, frozenset(self.actors.items()))) - - def __eq__(self, other): - return self.__class__ == other.__class__ and hash(self) == hash(other) - - -@dataclass(frozen=True) -class Trip: - """A route with a single actor type with name and unique id.""" - - vehicle_name: str - """The name of the vehicle. It must be unique. """ - route: Union[RandomRoute, Route] - """The route for the actor to attempt to follow.""" - vehicle_type: str = "passenger" - """The type of the vehicle""" - depart: float = 0 - """Start time in seconds.""" - actor: Optional[TrafficActor] = field(default=None) - """The traffic actor model (usually vehicle) that will be used for the trip.""" - - def __post_init__(self): - object.__setattr__( - self, - "actor", - ( - replace( - self.actor, name=self.vehicle_name, vehicle_type=self.vehicle_type - ) - if self.actor is not None - else TrafficActor( - name=self.vehicle_name, vehicle_type=self.vehicle_type - ) - ), - ) - - @property - def id(self) -> str: - """The unique id of this trip.""" - return self.vehicle_name - - def __hash__(self): - # Custom hash since self.actors is not hashable, here we first convert to a - # frozenset. - return pickle_hash_int((self.route, self.actor)) - - def __eq__(self, other): - return self.__class__ == other.__class__ and hash(self) == hash(other) - - -@dataclass(frozen=True) -class JunctionEdgeIDResolver: - """A utility for resolving a junction connection edge""" - - start_edge_id: str - start_lane_index: int - end_edge_id: str - end_lane_index: int - - def to_edge(self, sumo_road_network) -> str: - """Queries the road network to see if there is a junction edge between the two - given edges. - """ - return sumo_road_network.get_edge_in_junction( - self.start_edge_id, - self.start_lane_index, - self.end_edge_id, - self.end_lane_index, - ) - - -@dataclass(frozen=True) -class Via: - """A point on a road that an actor must pass through""" - - road_id: Union[str, JunctionEdgeIDResolver] - """The road this via is on""" - lane_index: int - """The lane this via sits on""" - lane_offset: int - """The offset along the lane where this via sits""" - required_speed: float - """The speed that a vehicle should travel through this via""" - hit_distance: float = -1 - """The distance at which this waypoint can be hit. Negative means half the lane radius.""" - - -@dataclass(frozen=True) -class Traffic: - """The descriptor for traffic.""" - - flows: Sequence[Flow] - """Flows are used to define a steady supply of vehicles.""" - # TODO: consider moving TrafficHistory stuff in here (and rename to Trajectory) - # TODO: - treat history points like Vias (no guarantee on history timesteps anyway) - trips: Optional[Sequence[Trip]] = None - """Trips are used to define a series of single vehicle trip.""" - engine: str = "SUMO" - """Traffic-generation engine to use. Supported values include "SUMO" and "SMARTS". "SUMO" requires using a SumoRoadNetwork for the RoadMap. - """ - - -class ConditionOperator(IntEnum): - """Represents logical operators between conditions.""" - - CONJUNCTION = enum.auto() - """Evaluate true if both operands are true, otherwise false.""" - - DISJUNCTION = enum.auto() - """Evaluate true if either operand is true, otherwise false.""" - - IMPLICATION = enum.auto() - """Evaluate true if either the first operand is false, or both operands are true, otherwise false.""" - - ## This would be desirable but makes the implementation more difficult in comparison to a negated condition. - # NEGATION=enum.auto() - # """True if its operand is false, otherwise false.""" - - -class ConditionRequires(IntFlag): - """This bitfield lays out the required information that a condition needs in order to evaluate.""" - - none = 0 - - # MISSION CONSTANTS - agent_id = enum.auto() - mission = enum.auto() - - # SIMULATION STATE - time = enum.auto() - actor_ids = enum.auto() - actor_states = enum.auto() - road_map = enum.auto() - simulation = enum.auto() - - # ACTOR STATE - current_actor_state = enum.auto() - current_actor_road_status = enum.auto() - - any_simulation_state = time | actor_ids | actor_states | simulation - any_current_actor_state = mission | current_actor_state | current_actor_road_status - any_mission_state = agent_id | mission - - -@dataclass(frozen=True) -class Condition: - """This encompasses an expression to evaluate to a logical result.""" - - def evaluate(self, **kwargs) -> ConditionState: - """Used to evaluate if a condition is met. - - Returns: - ConditionState: The evaluation result of the condition. - """ - raise NotImplementedError() - - @property - def requires(self) -> ConditionRequires: - """Information that the condition requires to evaluate state. - - Returns: - ConditionRequires: The types of information this condition needs in order to evaluate. - """ - raise NotImplementedError() - - def negation(self) -> "NegatedCondition": - """Negates this condition giving the opposite result on evaluation. - - >>> condition_true = LiteralCondition(ConditionState.TRUE) - >>> condition_true.evaluate() - - >>> condition_false = condition_true.negation() - >>> condition_false.evaluate() - - - Note\\: This erases temporal values EXPIRED and BEFORE. - >>> condition_before = LiteralCondition(ConditionState.BEFORE) - >>> condition_before.negation().negation().evaluate() - - - Returns: - NegatedCondition: The wrapped condition. - """ - return NegatedCondition(self) - - def conjunction(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A AND B. - - The bit AND operator has been overloaded to call this method. - >>> dependee_condition = DependeeActorCondition("leader") - >>> dependee_condition.evaluate(actor_ids={"leader"}) - - >>> conjunction = dependee_condition & LiteralCondition(ConditionState.FALSE) - >>> conjunction.evaluate(actor_ids={"leader"}) - - - Note that the resolution has the priority EXPIRED > BEFORE > FALSE > TRUE. - >>> conjunction = LiteralCondition(ConditionState.TRUE) & LiteralCondition(ConditionState.BEFORE) - >>> conjunction.evaluate() - - >>> (conjunction & LiteralCondition(ConditionState.EXPIRED)).evaluate() - - - Returns: - CompoundCondition: A condition combining two conditions using an AND operation. - """ - return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) - - def disjunction(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A OR B. - - The bit OR operator has been overloaded to call this method. - >>> disjunction = LiteralCondition(ConditionState.TRUE) | LiteralCondition(ConditionState.BEFORE) - >>> disjunction.evaluate() - - - Note that the resolution has the priority TRUE > BEFORE > FALSE > EXPIRED. - >>> disjunction = LiteralCondition(ConditionState.FALSE) | LiteralCondition(ConditionState.EXPIRED) - >>> disjunction.evaluate() - - >>> (disjunction | LiteralCondition(ConditionState.BEFORE)).evaluate() - - """ - return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) - - def implication(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A IMPLIES B. This is the same as A AND B OR NOT A.""" - return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) - - def trigger( - self, delay_seconds: float, persistant: bool = False - ) -> "ConditionTrigger": - """Converts the condition to a trigger which becomes permanently TRUE after the first time the inner condition becomes TRUE. - - >>> trigger = TimeWindowCondition(2, 5).trigger(delay_seconds=0) - >>> trigger.evaluate(time=1) - - >>> trigger.evaluate(time=4) - - >>> trigger.evaluate(time=90) - - - >>> start_time = 5 - >>> between_time = 10 - >>> delay_seconds = 20 - >>> trigger = LiteralCondition(ConditionState.TRUE).trigger(delay_seconds=delay_seconds) - >>> trigger.evaluate(time=start_time) - - >>> trigger.evaluate(time=between_time) - - >>> trigger.evaluate(time=start_time + delay_seconds) - - >>> trigger.evaluate(time=between_time) - - - Args: - delay_seconds (float): Applies the trigger after the delay has passed since the inner condition first TRUE. Defaults to False. - persistant (bool, optional): Mixes the inner result with the trigger result using an AND operation. - - Returns: - ConditionTrigger: A resulting condition. - """ - return ConditionTrigger( - self, delay_seconds=delay_seconds, persistant=persistant - ) - - def expire( - self, time, expired_state=ConditionState.EXPIRED, relative: bool = False - ) -> "ExpireTrigger": - """This trigger evaluates to the expired state value after the given simulation time. - - >>> trigger = LiteralCondition(ConditionState.TRUE).expire(20) - >>> trigger.evaluate(time=10) - - >>> trigger.evaluate(time=30) - - - Args: - time (float): The simulation time when this trigger changes. - expired_state (ConditionState, optional): The condition state to use when the simulation is after the given time. Defaults to ConditionState.EXPIRED. - relative (bool, optional): If this trigger should resolve relative to the first evaluated time. - Returns: - ExpireTrigger: The resulting condition. - """ - return ExpireTrigger( - inner_condition=self, - time=time, - expired_state=expired_state, - relative=relative, - ) - - def __and__(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A AND B.""" - assert isinstance(other, Condition) - return self.conjunction(other) - - def __or__(self, other: "Condition") -> "CompoundCondition": - """Resolve conditions as A OR B.""" - assert isinstance(other, Condition) - return self.disjunction(other) - - def __neg__(self) -> "NegatedCondition": - """Negates this condition""" - return self.negation() - - -@dataclass(frozen=True) -class SubjectCondition(Condition): - """This condition assumes that there is a subject involved.""" - - def evaluate(self, **kwargs) -> ConditionState: - """Used to evaluate if a condition is met. - - Args: - actor_info: Information about the currently relevant actor. - Returns: - ConditionState: The evaluation result of the condition. - """ - raise NotImplementedError() - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.current_actor_state - - -_abstract_conditions = (Condition, SubjectCondition) - - -@dataclass(frozen=True) -class LiteralCondition(Condition): - """This condition evaluates as a literal without considering evaluation parameters.""" - - literal: ConditionState - """The literal value of this condition.""" - - def evaluate(self, **kwargs) -> ConditionState: - return self.literal - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.none - - -@dataclass(frozen=True) -class TimeWindowCondition(Condition): - """This condition should be true in the given simulation time window.""" - - start: float - """The starting simulation time before which this condition becomes false.""" - end: float - """The ending simulation time as of which this condition becomes expired.""" - - def evaluate(self, **kwargs) -> ConditionState: - time = kwargs[ConditionRequires.time.name] - if self.start <= time < self.end or self.end == sys.maxsize: - return ConditionState.TRUE - elif time > self.end: - return ConditionState.EXPIRED - return ConditionState.BEFORE - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.time - - -@dataclass(frozen=True) -class DependeeActorCondition(Condition): - """This condition should be true if the given actor exists.""" - - actor_id: str - """The id of an actor in the simulation that needs to exist for this condition to be true.""" - - def evaluate(self, **kwargs) -> ConditionState: - actor_ids = kwargs[self.requires.name] - if self.actor_id in actor_ids: - return ConditionState.TRUE - return ConditionState.FALSE - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.actor_ids - - def __post_init__(self): - assert isinstance(self.actor_id, str) - - -@dataclass(frozen=True) -class NegatedCondition(Condition): - """This condition negates the inner condition to flip between TRUE and FALSE. - - Note\\: This erases temporal values EXPIRED and BEFORE. - """ - - inner_condition: Condition - """The inner condition to negate.""" - - def evaluate(self, **kwargs) -> ConditionState: - result = self.inner_condition.evaluate(**kwargs) - if ConditionState.TRUE in result: - return ConditionState.FALSE - return ConditionState.TRUE - - @property - def requires(self) -> ConditionRequires: - return self.inner_condition.requires - - def __post_init__(self): - if self.inner_condition.__class__ in _abstract_conditions: - raise TypeError( - f"Abstract `{self.inner_condition.__class__.__name__}` cannot use the negation operation." - ) - - -@dataclass(frozen=True) -class ExpireTrigger(Condition): - """This condition allows for expiration after a given time.""" - - inner_condition: Condition - """The inner condition to delay.""" - - time: float - """The simulation time when this trigger becomes expired.""" - - expired_state: ConditionState = ConditionState.EXPIRED - """The state value this trigger should have when it expires.""" - - relative: bool = False - """If this should start relative to the first time evaluated.""" - - def evaluate(self, **kwargs) -> ConditionState: - time = kwargs[ConditionRequires.time.name] - if self.relative: - key = "met" - met_time = getattr(self, key, -1) - if met_time == -1: - object.__setattr__(self, key, time) - time = 0 - else: - time -= met_time - if time >= self.time: - return self.expired_state - return self.inner_condition.evaluate(**kwargs) - - @cached_property - def requires(self) -> ConditionRequires: - return self.inner_condition.requires | ConditionRequires.time - - def __post_init__(self): - if self.inner_condition.__class__ in _abstract_conditions: - raise TypeError( - f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." - ) - - -@dataclass(frozen=True) -class ConditionTrigger(Condition): - """This condition is a trigger that assumes an untriggered constant state and then turns to the other state permanently - on the inner condition becoming TRUE. There is also an option to delay repsonse to the the inner condition by a number - of seconds. This will convey an EXPIRED value immediately because that state means the inner value will never be TRUE. - - This can be used to wait for some time after the inner condition has become TRUE to trigger. - Note that the original condition may no longer be true by the time delay has expired. - - This will never resolve TRUE on the first evaluate. - """ - - inner_condition: Condition - """The inner condition to delay.""" - - delay_seconds: float - """The number of seconds to delay for.""" - - untriggered_state: ConditionState = ConditionState.BEFORE - """The state before the inner trigger condition and delay is resolved.""" - - triggered_state: ConditionState = ConditionState.TRUE - """The state after the inner trigger condition and delay is resolved.""" - - persistant: bool = False - """If the inner condition state is used in conjuction with the triggered state. (inner_condition_state & triggered_state)""" - - def evaluate(self, **kwargs) -> ConditionState: - time = kwargs[ConditionRequires.time.name] - key = "met_time" - result = self.untriggered_state - met_time = getattr(self, key, -1) - if met_time == -1: - if self.inner_condition.evaluate(**kwargs): - object.__setattr__(self, key, time) - time = 0 - else: - time = -1 - else: - time -= met_time - if time >= self.delay_seconds: - result = self.triggered_state - if self.persistant: - result &= self.inner_condition.evaluate(**kwargs) - return result - - temporals = result & (ConditionState.EXPIRED) - if ConditionState.EXPIRED in temporals: - return ConditionState.EXPIRED - return self.untriggered_state - - @property - def requires(self) -> ConditionRequires: - return self.inner_condition.requires | ConditionRequires.time - - def __post_init__(self): - if self.inner_condition.__class__ in _abstract_conditions: - raise TypeError( - f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." - ) - if self.delay_seconds < 0: - raise ValueError("Delay cannot be negative.") - - -@dataclass(frozen=True) -class OffRoadCondition(SubjectCondition): - """This condition is true if the subject is on road.""" - - def evaluate(self, **kwargs) -> ConditionState: - current_actor_road_status = kwargs[self.requires.name] - if ( - current_actor_road_status.road is None - and not current_actor_road_status.off_road - ): - return ConditionState.BEFORE - return ( - ConditionState.TRUE - if current_actor_road_status.off_road - else ConditionState.FALSE - ) - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.current_actor_road_status - - -@dataclass(frozen=True) -class VehicleTypeCondition(SubjectCondition): - """This condition is true if the subject is of the given vehicle types.""" - - vehicle_type: str - - def evaluate(self, **kwargs) -> ConditionState: - current_actor_state = kwargs[self.requires.name] - return ( - ConditionState.TRUE - if current_actor_state.vehicle_config_type == self.vehicle_type - else ConditionState.FALSE - ) - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.current_actor_state - - -@dataclass(frozen=True) -class VehicleSpeedCondition(SubjectCondition): - """This condition is true if the subject has a speed between low and high.""" - - low: float - """The lowest speed allowed.""" - - high: float - """The highest speed allowed.""" - - def evaluate(self, **kwargs) -> ConditionState: - vehicle_state = kwargs[self.requires.name] - return ( - ConditionState.TRUE - if self.low <= vehicle_state.speed <= self.high - else ConditionState.FALSE - ) - - @property - def requires(self) -> ConditionRequires: - return ConditionRequires.current_actor_state - - @classmethod - def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): - """Generates a speed condition which assumes that the subject is stationary.""" - return cls(low=abs_error, high=abs_error) - - -@dataclass(frozen=True) -class CompoundCondition(Condition): - """This compounds multiple conditions. - - The following cases are notable - CONJUNCTION (A AND B) - If both conditions evaluate TRUE the result is exclusively TRUE. - Else if either condition evaluates EXPIRED the result will be EXPIRED. - Else if either condition evaluates BEFORE the result will be BEFORE. - Else FALSE - DISJUNCTION (A OR B) - If either condition evaluates TRUE the result is exclusively TRUE. - Else if either condition evaluates BEFORE then the result will be BEFORE. - Else if both conditions evaluate EXPIRED then the result will be EXPIRED. - Else FALSE - IMPLICATION (A AND B or not A) - If the first condition evaluates *not* TRUE the result is exclusively TRUE. - Else if the first condition evaluates TRUE and the second condition evaluates TRUE the result is exclusively TRUE. - Else FALSE - """ - - first_condition: Condition - """The first condition.""" - - second_condition: Condition - """The second condition.""" - - operator: ConditionOperator - """The operator used to combine these conditions.""" - - def evaluate(self, **kwargs) -> ConditionState: - # Short circuits - first_eval = self.first_condition.evaluate(**kwargs) - if ( - self.operator == ConditionOperator.CONJUNCTION - and ConditionState.EXPIRED in first_eval - ): - return ConditionState.EXPIRED - elif ( - self.operator == ConditionOperator.DISJUNCTION - and ConditionState.TRUE in first_eval - ): - return ConditionState.TRUE - elif ( - self.operator == ConditionOperator.IMPLICATION - and ConditionState.TRUE not in first_eval - ): - return ConditionState.TRUE - - second_eval = self.second_condition.evaluate(**kwargs) - if ( - self.operator == ConditionOperator.IMPLICATION - and ConditionState.TRUE in first_eval - and ConditionState.TRUE in second_eval - ): - return ConditionState.TRUE - - elif self.operator == ConditionOperator.CONJUNCTION: - conjuction = first_eval & second_eval - if ConditionState.TRUE in conjuction: - return ConditionState.TRUE - - # To priority of temporal versions of FALSE - disjunction = first_eval | second_eval - if ConditionState.EXPIRED in disjunction: - return ConditionState.EXPIRED - - if ConditionState.BEFORE in disjunction: - return ConditionState.BEFORE - - elif self.operator == ConditionOperator.DISJUNCTION: - result = first_eval | second_eval - - if ConditionState.TRUE in result: - return ConditionState.TRUE - - if ConditionState.BEFORE in result: - return ConditionState.BEFORE - - if ConditionState.EXPIRED in first_eval & second_eval: - return ConditionState.EXPIRED - - return ConditionState.FALSE - - @cached_property - def requires(self) -> ConditionRequires: - return self.first_condition.requires | self.second_condition.requires - - def __post_init__(self): - for condition in (self.first_condition, self.second_condition): - if condition.__class__ in _abstract_conditions: - raise TypeError( - f"Abstract `{condition.__class__.__name__}` cannot use compound operations." - ) - - -@dataclass(frozen=True) -class EntryTactic: - """The tactic that the simulation should use to acquire a vehicle for an agent.""" - - start_time: float - - def __post_init__(self): - assert ( - getattr(self, "condition", None) is not None - ), "Abstract class, inheriting types must implement the `condition` field." - - -@dataclass(frozen=True) -class TrapEntryTactic(EntryTactic): - """An entry tactic that repurposes a pre-existing vehicle for an agent.""" - - wait_to_hijack_limit_s: float = 0 - """The amount of seconds a hijack will wait to get a vehicle before defaulting to a new vehicle""" - zone: Optional["MapZone"] = None - """The zone of the hijack area""" - exclusion_prefixes: Tuple[str, ...] = tuple() - """The prefixes of vehicles to avoid hijacking""" - default_entry_speed: Optional[float] = None - """The speed that the vehicle starts at when the hijack limit expiry emits a new vehicle""" - condition: Condition = LiteralCondition(ConditionState.TRUE) - """A condition that is used to add additional exclusions.""" - - def __post_init__(self): - assert isinstance(self.condition, (Condition)) - assert not ( - self.condition.requires & ConditionRequires.any_current_actor_state - ), f"Trap entry tactic cannot use conditions that require any_vehicle_state." - - -@dataclass(frozen=True) -class IdEntryTactic(EntryTactic): - """An entry tactic which repurposes a pre-existing actor for an agent. Selects that actor by id.""" - - actor_id: str - """The id of the actor to take over.""" - - condition: Condition = TimeWindowCondition(0.1, sys.maxsize) - """A condition that is used to add additional exclusions.""" - - def __post_init__(self): - assert isinstance(self.actor_id, str) - assert isinstance(self.condition, (Condition)) - - -@dataclass(frozen=True) -class Mission: - """The descriptor for an actor's mission.""" - - route: Union[RandomRoute, Route] - """The route for the actor to attempt to follow.""" - - via: Tuple[Via, ...] = () - """Points on an road that an actor must pass through""" - - start_time: float = MISSING - """The earliest simulation time that this mission starts but may start later in couple with - `entry_tactic`. - """ - - entry_tactic: Optional[EntryTactic] = None - """A specific tactic the mission should employ to start the mission.""" - - def __post_init__(self): - if self.start_time != sys.maxsize: - warnings.warn( - "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", - category=DeprecationWarning, - ) - - -@dataclass(frozen=True) -class EndlessMission: - """The descriptor for an actor's mission that has no end.""" - - begin: Tuple[str, int, float] - """The (road, lane_index, offset) details of the start location for the route. - - road: - The starting road by name. - lane_index: - The lane index from the rightmost lane. - offset: - The offset in metres into the lane. Also acceptable\\: 'max', 'random' - """ - via: Tuple[Via, ...] = () - """Points on a road that an actor must pass through""" - start_time: float = MISSING - """The earliest simulation time that this mission starts""" - entry_tactic: Optional[EntryTactic] = None - """A specific tactic the mission should employ to start the mission""" - - def __post_init__(self): - if self.start_time != sys.maxsize: - warnings.warn( - "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", - category=DeprecationWarning, - ) - - -@dataclass(frozen=True) -class LapMission: - """The descriptor for an actor's mission that defines mission that repeats - in a closed loop. - """ - - route: Route - """The route for the actor to attempt to follow""" - num_laps: int - """The amount of times to repeat the mission""" - via: Tuple[Via, ...] = () - """Points on a road that an actor must pass through""" - start_time: float = MISSING - """The earliest simulation time that this mission starts""" - entry_tactic: Optional[EntryTactic] = None - """A specific tactic the mission should employ to start the mission""" - - def __post_init__(self): - if self.start_time != sys.maxsize: - warnings.warn( - "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", - category=DeprecationWarning, - ) - - -@dataclass(frozen=True) -class GroupedLapMission: - """The descriptor for a group of actor missions that repeat in a closed loop.""" - - route: Route - """The route for the actors to attempt to follow""" - offset: int - """The offset of the "starting line" for the group""" - lanes: int - """The number of lanes the group occupies""" - actor_count: int - """The number of actors to be part of the group""" - num_laps: int - """The amount of times to repeat the mission""" - via: Tuple[Via, ...] = () - """Points on a road that an actor must pass through""" - entry_tactic: Optional[EntryTactic] = None - """A specific tactic the mission should employ to start the mission""" - - -@dataclass(frozen=True) -class Zone: - """The base for a descriptor that defines a capture area.""" - - def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: - """Generates the geometry from this zone.""" - raise NotImplementedError - - -@dataclass(frozen=True) -class MapZone(Zone): - """A descriptor that defines a capture area.""" - - start: Tuple[str, int, float] - """The (road_id, lane_index, offset) details of the starting location. - - road_id: - The starting road by name. - lane_index: - The lane index from the rightmost lane. - offset: - The offset in metres into the lane. Also acceptable\\: 'max', 'random' - """ - length: float - """The length of the geometry along the center of the lane. Also acceptable\\: 'max'""" - n_lanes: int = 2 - """The number of lanes from right to left that this zone covers.""" - - def to_geometry(self, road_map: Optional[RoadMap]) -> Polygon: - """Generates a map zone over a stretch of the given lanes.""" - - assert ( - road_map is not None - ), f"{self.__class__.__name__} requires a road map to resolve geometry." - - def resolve_offset(offset, geometry_length, lane_length): - if offset == "base": - return 0 - # push off of end of lane - elif offset == "max": - return lane_length - geometry_length - elif offset == "random": - return random.uniform(0, lane_length - geometry_length) - else: - return float(offset) - - def pick_remaining_shape_after_split(geometry_collection, expected_point): - lane_shape = geometry_collection - if not isinstance(lane_shape, GeometryCollection): - return lane_shape - - # For simplicity, we only deal w/ the == 1 or 2 case - if len(lane_shape.geoms) not in {1, 2}: - return None - - if len(lane_shape.geoms) == 1: - return lane_shape.geoms[0] - - # We assume that there are only two split shapes to choose from - keep_index = 0 - if lane_shape.geoms[1].minimum_rotated_rectangle.contains(expected_point): - # 0 is the discard piece, keep the other - keep_index = 1 - - lane_shape = lane_shape.geoms[keep_index] - - return lane_shape - - def split_lane_shape_at_offset( - lane_shape: Polygon, lane: RoadMap.Lane, offset: float - ): - # XXX: generalize to n-dim - width_2, _ = lane.width_at_offset(offset) - point = np.array(lane.from_lane_coord(RefLinePoint(offset)))[:2] - lane_vec = lane.vector_at_offset(offset)[:2] - - perp_vec_right = rotate_cw_around_point(lane_vec, np.pi / 2, origin=(0, 0)) - perp_vec_right = ( - perp_vec_right / max(np.linalg.norm(perp_vec_right), 1e-3) * width_2 - + point - ) - - perp_vec_left = rotate_cw_around_point(lane_vec, -np.pi / 2, origin=(0, 0)) - perp_vec_left = ( - perp_vec_left / max(np.linalg.norm(perp_vec_left), 1e-3) * width_2 - + point - ) - - split_line = LineString([perp_vec_left, perp_vec_right]) - return split(lane_shape, split_line) - - lane_shapes = [] - road_id, lane_idx, offset = self.start - road = road_map.road_by_id(road_id) - buffer_from_ends = 1e-6 - for lane_idx in range(lane_idx, lane_idx + self.n_lanes): - lane = road.lane_at_index(lane_idx) - lane_length = lane.length - geom_length = self.length - - if geom_length > lane_length: - logging.debug( - f"Geometry is too long={geom_length} with offset={offset} for " - f"lane={lane.lane_id}, using length={lane_length} instead" - ) - geom_length = lane_length - - assert geom_length > 0 # Geom length is negative - - lane_offset = resolve_offset(offset, geom_length, lane_length) - lane_offset += buffer_from_ends - width, _ = lane.width_at_offset(lane_offset) # TODO - lane_shape = lane.shape(0.3, width) # TODO - - geom_length = max(geom_length - buffer_from_ends, buffer_from_ends) - lane_length = max(lane_length - buffer_from_ends, buffer_from_ends) - - min_cut = min(lane_offset, lane_length) - # Second cut takes into account shortening of geometry by `min_cut`. - max_cut = min(min_cut + geom_length, lane_length) - - midpoint = Point( - *lane.from_lane_coord(RefLinePoint(s=lane_offset + geom_length * 0.5)) - ) - - lane_shape = split_lane_shape_at_offset(lane_shape, lane, min_cut) - lane_shape = pick_remaining_shape_after_split(lane_shape, midpoint) - if lane_shape is None: - continue - - lane_shape = split_lane_shape_at_offset( - lane_shape, - lane, - max_cut, - ) - lane_shape = pick_remaining_shape_after_split(lane_shape, midpoint) - if lane_shape is None: - continue - - lane_shapes.append(lane_shape) - - geom = unary_union(MultiPolygon(lane_shapes)) - return geom - - -@dataclass(frozen=True) -class PositionalZone(Zone): - """A descriptor that defines a capture area at a specific XY location.""" - - # center point - pos: Tuple[float, float] - """A (x,y) position of the zone in the scenario.""" - size: Tuple[float, float] - """The (length, width) dimensions of the zone.""" - rotation: Optional[float] = None - """The heading direction of the bubble. (radians, clock-wise rotation)""" - - def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: - """Generates a box zone at the given position.""" - w, h = self.size - x, y = self.pos[:2] - p0 = (-w / 2, -h / 2) # min - p1 = (w / 2, h / 2) # max - poly = Polygon([p0, (p0[0], p1[1]), p1, (p1[0], p0[1])]) - if self.rotation is not None: - poly = shapely_rotate(poly, self.rotation, use_radians=True) - return shapely_translate(poly, xoff=x, yoff=y) - - -@dataclass(frozen=True) -class ConfigurableZone(Zone): - """A descriptor for a zone with user-defined geometry.""" - - ext_coordinates: List[Tuple[float, float]] - """external coordinates of the polygon - < 2 points provided: error - = 2 points provided: generates a box using these two points as diagonal - > 2 points provided: generates a polygon according to the coordinates""" - rotation: Optional[float] = None - """The heading direction of the bubble(radians, clock-wise rotation)""" - - def __post_init__(self): - if ( - not self.ext_coordinates - or len(self.ext_coordinates) < 2 - or not isinstance(self.ext_coordinates[0], tuple) - ): - raise ValueError( - "Two points or more are needed to create a polygon. (less than two points are provided)" - ) - - x_set = set(point[0] for point in self.ext_coordinates) - y_set = set(point[1] for point in self.ext_coordinates) - if len(x_set) == 1 or len(y_set) == 1: - raise ValueError( - "Parallel line cannot form a polygon. (points provided form a parallel line)" - ) - - def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: - """Generate a polygon according to given coordinates""" - poly = None - if ( - len(self.ext_coordinates) == 2 - ): # if user only specified two points, create a box - x_min = min(self.ext_coordinates[0][0], self.ext_coordinates[1][0]) - x_max = max(self.ext_coordinates[0][0], self.ext_coordinates[1][0]) - y_min = min(self.ext_coordinates[0][1], self.ext_coordinates[1][1]) - y_max = max(self.ext_coordinates[0][1], self.ext_coordinates[1][1]) - poly = box(x_min, y_min, x_max, y_max) - - else: # else create a polygon according to the coordinates - poly = Polygon(self.ext_coordinates) - - if self.rotation is not None: - poly = shapely_rotate(poly, self.rotation, use_radians=True) - return poly - - -@dataclass(frozen=True) -class BubbleLimits: - """Defines the capture limits of a bubble.""" - - hijack_limit: int = sys.maxsize - """The maximum number of vehicles the bubble can hijack""" - shadow_limit: int = sys.maxsize - """The maximum number of vehicles the bubble can shadow""" - - def __post_init__(self): - if self.shadow_limit is None: - raise ValueError("Shadow limit must be a non-negative real number") - if self.hijack_limit is None or self.shadow_limit < self.hijack_limit: - raise ValueError("Shadow limit must be >= hijack limit") - - -@dataclass(frozen=True) -class Bubble: - """A descriptor that defines a capture bubble for social agents.""" - - zone: Zone - """The zone which to capture vehicles.""" - actor: SocialAgentActor - """The actor specification that this bubble works for.""" - margin: float = 2 # Used for "airlocking"; must be > 0 - """The exterior buffer area for airlocking. Must be > 0.""" - # If limit != None it will only allow that specified number of vehicles to be - # hijacked. N.B. when actor = BoidAgentActor the lesser of the actor capacity - # and bubble limit will be used. - limit: Optional[BubbleLimits] = None - """The maximum number of actors that could be captured.""" - exclusion_prefixes: Tuple[str, ...] = field(default_factory=tuple) - """Used to exclude social actors from capture.""" - id: str = field(default_factory=lambda: f"bubble-{gen_id()}") - follow_actor_id: Optional[str] = None - """Actor ID of agent we want to pin to. Doing so makes this a "travelling bubble" - which means it moves to follow the `follow_actor_id`'s vehicle. Offset is from the - vehicle's center position to the bubble's center position. - """ - follow_offset: Optional[Tuple[float, float]] = None - """Maintained offset to place the travelling bubble relative to the follow - vehicle if it were facing north. - """ - keep_alive: bool = False - """If enabled, the social agent actor will be spawned upon first vehicle airlock - and be reused for every subsequent vehicle entering the bubble until the episode - is over. - """ - follow_vehicle_id: Optional[str] = None - """Vehicle ID of a vehicle we want to pin to. Doing so makes this a "travelling bubble" - which means it moves to follow the `follow_vehicle_id`'s vehicle. Offset is from the - vehicle's center position to the bubble's center position. - """ - - def __post_init__(self): - if self.margin < 0: - raise ValueError("Airlocking margin must be greater than 0") - - if self.follow_actor_id is not None and self.follow_vehicle_id is not None: - raise ValueError( - "Only one option of follow actor id and follow vehicle id can be used at any time." - ) - - if ( - self.follow_actor_id is not None or self.follow_vehicle_id is not None - ) and self.follow_offset is None: - raise ValueError( - "A follow offset must be set if this is a travelling bubble" - ) - - if self.keep_alive and not self.is_boid: - # TODO: We may want to remove this restriction in the future - raise ValueError( - "Only boids can have keep_alive enabled (for persistent boids)" - ) - - if not isinstance(self.zone, MapZone): - poly = self.zone.to_geometry(road_map=None) - if not poly.is_valid: - follow_id = ( - self.follow_actor_id - if self.follow_actor_id - else self.follow_vehicle_id - ) - raise ValueError( - f"The zone polygon of {type(self.zone).__name__} of moving {self.id} which following {follow_id} is not a valid closed loop" - if follow_id - else f"The zone polygon of {type(self.zone).__name__} of fixed position {self.id} is not a valid closed loop" - ) - - @staticmethod - def to_actor_id(actor, mission_group): - """Mashes the actor id and mission group to create what needs to be a unique id.""" - return SocialAgentId.new(actor.name, group=mission_group) - - @property - def is_boid(self): - """Tests if the actor is to control multiple vehicles.""" - return isinstance(self.actor, BoidAgentActor) - - -@dataclass(frozen=True) -class RoadSurfacePatch: - """A descriptor that defines a patch of road surface with a different friction coefficient.""" - - zone: Zone - """The zone which to capture vehicles.""" - begin_time: int - """The start time in seconds of when this surface is active.""" - end_time: int - """The end time in seconds of when this surface is active.""" - friction_coefficient: float - """The surface friction coefficient.""" - - -@dataclass(frozen=True) -class ActorAndMission: - """Holds an Actor object and its associated Mission.""" - - actor: Actor - """Specification for traffic actor. - """ - mission: Union[Mission, EndlessMission, LapMission] - """Mission for traffic actor. - """ - - -@dataclass(frozen=True) -class TrafficHistoryDataset: - """Describes a dataset containing trajectories (time-stamped positions) - for a set of vehicles. Often these have been collected by third parties - from real-world observations, hence the name 'history'. When used - with a SMARTS scenario, traffic vehicles will move on the map according - to their trajectories as specified in the dataset. These can be mixed - with other types of traffic (such as would be specified by an object of - the Traffic type in this DSL). In order to use this efficiently, SMARTS - will pre-process ('import') the dataset when the scenario is built.""" - - name: str - """a unique name for the dataset""" - source_type: str - """the type of the dataset; supported values include: NGSIM, INTERACTION, Waymo""" - input_path: Optional[str] = None - """a relative or absolute path to the dataset; if omitted, dataset will not be imported""" - scenario_id: Optional[str] = None - """a unique ID for a Waymo scenario. For other datasets, this field will be None.""" - x_margin_px: float = 0.0 - """x offset of the map from the data (in pixels)""" - y_margin_px: float = 0.0 - """y offset of the map from the data (in pixels)""" - swap_xy: bool = False - """if True, the x and y axes the dataset coordinate system will be swapped""" - flip_y: bool = False - """if True, the dataset will be mirrored around the x-axis""" - filter_off_map: bool = False - """if True, then any vehicle whose coordinates on a time step fall outside of the map's bounding box will be removed for that time step""" - - map_lane_width: float = 3.7 - """This is used to figure out the map scale, which is map_lane_width / real_lane_width_m. (So use `real_lane_width_m` here for 1:1 scale - the default.) It's also used in SMARTS for detecting off-road, etc.""" - real_lane_width_m: float = 3.7 - """Average width in meters of the dataset's lanes in the real world. US highway lanes are about 12 feet (or ~3.7m, the default) wide.""" - speed_limit_mps: Optional[float] = None - """used by SMARTS for the initial speed of new agents being added to the scenario""" - - heading_inference_window: int = 2 - """When inferring headings from positions, a sliding window (moving average) of this size will be used to smooth inferred headings and reduce their dependency on any individual position changes. Defaults to 2 if not specified.""" - heading_inference_min_speed: float = 2.2 - """Speed threshold below which a vehicle's heading is assumed not to change. This is useful to prevent abnormal heading changes that may arise from noise in position estimates in a trajectory dataset dominating real position changes in situations where the real position changes are very small. Defaults to 2.2 m/s if not specified.""" - max_angular_velocity: Optional[float] = None - """When inferring headings from positions, each vehicle's angular velocity will be limited to be at most this amount (in rad/sec) to prevent lateral-coordinate noise in the dataset from causing near-instantaneous heading changes.""" - default_heading: float = 1.5 * math.pi - """A heading in radians to be used by default for vehicles if the headings are not present in the dataset and cannot be inferred from position changes (such as on the first time step).""" - - -@dataclass(frozen=True) -class ScenarioMetadata: - """Scenario data that does not have influence on simulation.""" - - actor_of_interest_re_filter: str - """Vehicles with names that match this pattern are vehicles of interest.""" - actor_of_interest_color: Colors - """The color that the vehicles of interest should have.""" - - -@dataclass(frozen=True) -class Scenario: - """The sstudio scenario representation.""" - - map_spec: Optional[MapSpec] = None - """Specifies the road map.""" - traffic: Optional[Dict[str, Traffic]] = None - """Background traffic vehicle specification.""" - ego_missions: Optional[Sequence[Union[Mission, EndlessMission]]] = None - """Ego agent missions.""" - social_agent_missions: Optional[ - Dict[str, Tuple[Sequence[SocialAgentActor], Sequence[Mission]]] - ] = None - """ - Actors must have unique names regardless of which group they are assigned to. - Every dictionary item ``{group: (actors, missions)}`` gets selected from simultaneously. - If actors > 1 and missions = 0 or actors = 1 and missions > 0, we cycle - through them every episode. Otherwise actors must be the same length as - missions. - """ - bubbles: Optional[Sequence[Bubble]] = None - """Capture bubbles for focused social agent simulation.""" - friction_maps: Optional[Sequence[RoadSurfacePatch]] = None - """Friction coefficient of patches of road surface.""" - traffic_histories: Optional[Sequence[Union[TrafficHistoryDataset, str]]] = None - """Traffic vehicles trajectory dataset to be replayed.""" - scenario_metadata: Optional[ScenarioMetadata] = None - """"Scenario data that does not have influence on simulation.""" - - def __post_init__(self): - def _get_name(item): - return item.name - - if self.social_agent_missions is not None: - groups = [k for k in self.social_agent_missions] - for group, (actors, _) in self.social_agent_missions.items(): - for o_group in groups: - if group == o_group: - continue - if intersection := set.intersection( - set(map(_get_name, actors)), - map(_get_name, self.social_agent_missions[o_group][0]), - ): - raise ValueError( - f"Social agent mission groups `{group}`|`{o_group}` have overlapping actors {intersection}" - ) diff --git a/smarts/sstudio/types/__init__.py b/smarts/sstudio/types/__init__.py new file mode 100644 index 0000000000..9d5b830ae5 --- /dev/null +++ b/smarts/sstudio/types/__init__.py @@ -0,0 +1,35 @@ +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +from smarts.sstudio.types.actor.social_agent_actor import * +from smarts.sstudio.types.actor.traffic_actor import * +from smarts.sstudio.types.bubble import * +from smarts.sstudio.types.bubble_limits import * +from smarts.sstudio.types.condition import * +from smarts.sstudio.types.constants import * +from smarts.sstudio.types.dataset import * +from smarts.sstudio.types.distribution import * +from smarts.sstudio.types.entry_tactic import * +from smarts.sstudio.types.map_spec import * +from smarts.sstudio.types.mission import * +from smarts.sstudio.types.route import * +from smarts.sstudio.types.scenario import * +from smarts.sstudio.types.traffic import * +from smarts.sstudio.types.traffic_model import * +from smarts.sstudio.types.zone import * diff --git a/smarts/sstudio/types/actor/__init__.py b/smarts/sstudio/types/actor/__init__.py new file mode 100644 index 0000000000..04749050ae --- /dev/null +++ b/smarts/sstudio/types/actor/__init__.py @@ -0,0 +1,31 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Actor: + """This is the base description/spec type for traffic actors.""" + + pass diff --git a/smarts/sstudio/types/actor/social_agent_actor.py b/smarts/sstudio/types/actor/social_agent_actor.py new file mode 100644 index 0000000000..6096a847b8 --- /dev/null +++ b/smarts/sstudio/types/actor/social_agent_actor.py @@ -0,0 +1,67 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from smarts.core import gen_id +from smarts.sstudio.types.actor import Actor +from smarts.sstudio.types.bubble_limits import BubbleLimits + + +@dataclass(frozen=True) +class SocialAgentActor(Actor): + """Used as a description/spec for zoo traffic actors. These actors use a + pre-trained model to understand how to act in the environment. + """ + + name: str + """The name of the social actor. Must be unique.""" + + # A pre-registered zoo identifying tag you provide to help SMARTS identify the + # prefab of a social agent. + agent_locator: str + """The locator reference to the zoo registration call. Expects a string in the format + of 'path.to.file:locator-name' where the path to the registration call is in the form + {PYTHONPATH}[n]/path/to/file.py + """ + policy_kwargs: Dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments to be passed to the constructed class overriding the + existing registered arguments. + """ + initial_speed: Optional[float] = None + """Set the initial speed, defaults to 0.""" + + +@dataclass(frozen=True) +class BoidAgentActor(SocialAgentActor): + """Used as a description/spec for boid traffic actors. Boid actors control multiple + vehicles. + """ + + id: str = field(default_factory=lambda: f"boid-{gen_id()}") + + # The max number of vehicles that this agent will control at a time. This value is + # honored when using a bubble for boid dynamic assignment. + capacity: BubbleLimits = None + """The capacity of the boid agent to take over vehicles.""" diff --git a/smarts/sstudio/types/actor/traffic_actor.py b/smarts/sstudio/types/actor/traffic_actor.py new file mode 100644 index 0000000000..3f510b5b4b --- /dev/null +++ b/smarts/sstudio/types/actor/traffic_actor.py @@ -0,0 +1,75 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass, field +from typing import Union + +from smarts.core.utils.file import pickle_hash_int +from smarts.sstudio.types.actor import Actor +from smarts.sstudio.types.distribution import Distribution +from smarts.sstudio.types.traffic_model import JunctionModel, LaneChangingModel + + +@dataclass(frozen=True) +class TrafficActor(Actor): + """Used as a description/spec for traffic actors (e.x. Vehicles, Pedestrians, + etc). The defaults provided are for a car, but the name is not set to make it + explicit that you actually want a car. + """ + + name: str + """The name of the traffic actor. It must be unique.""" + accel: float = 2.6 + """The maximum acceleration value of the actor (in m/s^2).""" + decel: float = 4.5 + """The maximum deceleration value of the actor (in m/s^2).""" + tau: float = 1.0 + """The minimum time headway""" + sigma: float = 0.5 + """The driver imperfection""" # TODO: appears to not be used in generators.py + depart_speed: Union[float, str] = "max" + """The starting speed of the actor""" + emergency_decel: float = 4.5 + """maximum deceleration ability of vehicle in case of emergency""" + speed: Distribution = Distribution(mean=1.0, sigma=0.1) + """The speed distribution of this actor in m/s.""" + imperfection: Distribution = Distribution(mean=0.5, sigma=0) + """Driver imperfection within range [0..1]""" + min_gap: Distribution = Distribution(mean=2.5, sigma=0) + """Minimum gap (when standing) in meters.""" + max_speed: float = 55.5 + """The vehicle's maximum velocity (in m/s), defaults to 200 km/h for vehicles""" + vehicle_type: str = "passenger" + """The configured vehicle type this actor will perform as. ("passenger", "bus", "coach", "truck", "trailer")""" + lane_changing_model: LaneChangingModel = field( + default_factory=LaneChangingModel, hash=False + ) + junction_model: JunctionModel = field(default_factory=JunctionModel, hash=False) + + def __hash__(self) -> int: + return pickle_hash_int(self) + + @property + def id(self) -> str: + """The identifier tag of the traffic actor.""" + return "{}-{}".format(self.name, str(hash(self))[:6]) diff --git a/smarts/sstudio/types/bubble.py b/smarts/sstudio/types/bubble.py new file mode 100644 index 0000000000..be2f6e8115 --- /dev/null +++ b/smarts/sstudio/types/bubble.py @@ -0,0 +1,119 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass, field +from typing import Optional, Tuple + +from smarts.core import gen_id +from smarts.core.utils.id import SocialAgentId +from smarts.sstudio.types.actor.social_agent_actor import ( + BoidAgentActor, + SocialAgentActor, +) +from smarts.sstudio.types.bubble_limits import BubbleLimits +from smarts.sstudio.types.zone import MapZone, Zone + + +@dataclass(frozen=True) +class Bubble: + """A descriptor that defines a capture bubble for social agents.""" + + zone: Zone + """The zone which to capture vehicles.""" + actor: SocialAgentActor + """The actor specification that this bubble works for.""" + margin: float = 2 # Used for "airlocking"; must be > 0 + """The exterior buffer area for airlocking. Must be > 0.""" + # If limit != None it will only allow that specified number of vehicles to be + # hijacked. N.B. when actor = BoidAgentActor the lesser of the actor capacity + # and bubble limit will be used. + limit: Optional[BubbleLimits] = None + """The maximum number of actors that could be captured.""" + exclusion_prefixes: Tuple[str, ...] = field(default_factory=tuple) + """Used to exclude social actors from capture.""" + id: str = field(default_factory=lambda: f"bubble-{gen_id()}") + follow_actor_id: Optional[str] = None + """Actor ID of agent we want to pin to. Doing so makes this a "travelling bubble" + which means it moves to follow the `follow_actor_id`'s vehicle. Offset is from the + vehicle's center position to the bubble's center position. + """ + follow_offset: Optional[Tuple[float, float]] = None + """Maintained offset to place the travelling bubble relative to the follow + vehicle if it were facing north. + """ + keep_alive: bool = False + """If enabled, the social agent actor will be spawned upon first vehicle airlock + and be reused for every subsequent vehicle entering the bubble until the episode + is over. + """ + follow_vehicle_id: Optional[str] = None + """Vehicle ID of a vehicle we want to pin to. Doing so makes this a "travelling bubble" + which means it moves to follow the `follow_vehicle_id`'s vehicle. Offset is from the + vehicle's center position to the bubble's center position. + """ + + def __post_init__(self): + if self.margin < 0: + raise ValueError("Airlocking margin must be greater than 0") + + if self.follow_actor_id is not None and self.follow_vehicle_id is not None: + raise ValueError( + "Only one option of follow actor id and follow vehicle id can be used at any time." + ) + + if ( + self.follow_actor_id is not None or self.follow_vehicle_id is not None + ) and self.follow_offset is None: + raise ValueError( + "A follow offset must be set if this is a travelling bubble" + ) + + if self.keep_alive and not self.is_boid: + # TODO: We may want to remove this restriction in the future + raise ValueError( + "Only boids can have keep_alive enabled (for persistent boids)" + ) + + if not isinstance(self.zone, MapZone): + poly = self.zone.to_geometry(road_map=None) + if not poly.is_valid: + follow_id = ( + self.follow_actor_id + if self.follow_actor_id + else self.follow_vehicle_id + ) + raise ValueError( + f"The zone polygon of {type(self.zone).__name__} of moving {self.id} which following {follow_id} is not a valid closed loop" + if follow_id + else f"The zone polygon of {type(self.zone).__name__} of fixed position {self.id} is not a valid closed loop" + ) + + @staticmethod + def to_actor_id(actor, mission_group): + """Mashes the actor id and mission group to create what needs to be a unique id.""" + return SocialAgentId.new(actor.name, group=mission_group) + + @property + def is_boid(self): + """Tests if the actor is to control multiple vehicles.""" + return isinstance(self.actor, BoidAgentActor) diff --git a/smarts/sstudio/types/bubble_limits.py b/smarts/sstudio/types/bubble_limits.py new file mode 100644 index 0000000000..a4255bff2a --- /dev/null +++ b/smarts/sstudio/types/bubble_limits.py @@ -0,0 +1,42 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass + +from smarts.sstudio.types.constants import MAX + + +@dataclass(frozen=True) +class BubbleLimits: + """Defines the capture limits of a bubble.""" + + hijack_limit: int = MAX + """The maximum number of vehicles the bubble can hijack""" + shadow_limit: int = MAX + """The maximum number of vehicles the bubble can shadow""" + + def __post_init__(self): + if self.shadow_limit is None: + raise ValueError("Shadow limit must be a non-negative real number") + if self.hijack_limit is None or self.shadow_limit < self.hijack_limit: + raise ValueError("Shadow limit must be >= hijack limit") diff --git a/smarts/sstudio/types/condition.py b/smarts/sstudio/types/condition.py new file mode 100644 index 0000000000..92502adcf9 --- /dev/null +++ b/smarts/sstudio/types/condition.py @@ -0,0 +1,610 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import enum +import sys +from dataclasses import dataclass +from enum import IntEnum, IntFlag +from functools import cached_property +from typing import Type + +from smarts.core.condition_state import ConditionState + + +class ConditionOperator(IntEnum): + """Represents logical operators between conditions.""" + + CONJUNCTION = enum.auto() + """Evaluate true if both operands are true, otherwise false.""" + + DISJUNCTION = enum.auto() + """Evaluate true if either operand is true, otherwise false.""" + + IMPLICATION = enum.auto() + """Evaluate true if either the first operand is false, or both operands are true, otherwise false.""" + + ## This would be desirable but makes the implementation more difficult in comparison to a negated condition. + # NEGATION=enum.auto() + # """True if its operand is false, otherwise false.""" + + +class ConditionRequires(IntFlag): + """This bitfield lays out the required information that a condition needs in order to evaluate.""" + + none = 0 + + # MISSION CONSTANTS + agent_id = enum.auto() + mission = enum.auto() + + # SIMULATION STATE + time = enum.auto() + actor_ids = enum.auto() + actor_states = enum.auto() + road_map = enum.auto() + simulation = enum.auto() + + # ACTOR STATE + current_actor_state = enum.auto() + current_actor_road_status = enum.auto() + + any_simulation_state = time | actor_ids | actor_states | simulation + any_current_actor_state = mission | current_actor_state | current_actor_road_status + any_mission_state = agent_id | mission + + +@dataclass(frozen=True) +class Condition: + """This encompasses an expression to evaluate to a logical result.""" + + def evaluate(self, **kwargs) -> ConditionState: + """Used to evaluate if a condition is met. + + Returns: + ConditionState: The evaluation result of the condition. + """ + raise NotImplementedError() + + @property + def requires(self) -> ConditionRequires: + """Information that the condition requires to evaluate state. + + Returns: + ConditionRequires: The types of information this condition needs in order to evaluate. + """ + raise NotImplementedError() + + def negation(self) -> "NegatedCondition": + """Negates this condition giving the opposite result on evaluation. + + >>> condition_true = LiteralCondition(ConditionState.TRUE) + >>> condition_true.evaluate() + + >>> condition_false = condition_true.negation() + >>> condition_false.evaluate() + + + Note\\: This erases temporal values EXPIRED and BEFORE. + >>> condition_before = LiteralCondition(ConditionState.BEFORE) + >>> condition_before.negation().negation().evaluate() + + + Returns: + NegatedCondition: The wrapped condition. + """ + return NegatedCondition(self) + + def conjunction(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A AND B. + + The bit AND operator has been overloaded to call this method. + >>> dependee_condition = DependeeActorCondition("leader") + >>> dependee_condition.evaluate(actor_ids={"leader"}) + + >>> conjunction = dependee_condition & LiteralCondition(ConditionState.FALSE) + >>> conjunction.evaluate(actor_ids={"leader"}) + + + Note that the resolution has the priority EXPIRED > BEFORE > FALSE > TRUE. + >>> conjunction = LiteralCondition(ConditionState.TRUE) & LiteralCondition(ConditionState.BEFORE) + >>> conjunction.evaluate() + + >>> (conjunction & LiteralCondition(ConditionState.EXPIRED)).evaluate() + + + Returns: + CompoundCondition: A condition combining two conditions using an AND operation. + """ + return CompoundCondition(self, other, operator=ConditionOperator.CONJUNCTION) + + def disjunction(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A OR B. + + The bit OR operator has been overloaded to call this method. + >>> disjunction = LiteralCondition(ConditionState.TRUE) | LiteralCondition(ConditionState.BEFORE) + >>> disjunction.evaluate() + + + Note that the resolution has the priority TRUE > BEFORE > FALSE > EXPIRED. + >>> disjunction = LiteralCondition(ConditionState.FALSE) | LiteralCondition(ConditionState.EXPIRED) + >>> disjunction.evaluate() + + >>> (disjunction | LiteralCondition(ConditionState.BEFORE)).evaluate() + + """ + return CompoundCondition(self, other, operator=ConditionOperator.DISJUNCTION) + + def implication(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A IMPLIES B. This is the same as A AND B OR NOT A.""" + return CompoundCondition(self, other, operator=ConditionOperator.IMPLICATION) + + def trigger( + self, delay_seconds: float, persistant: bool = False + ) -> "ConditionTrigger": + """Converts the condition to a trigger which becomes permanently TRUE after the first time the inner condition becomes TRUE. + + >>> trigger = TimeWindowCondition(2, 5).trigger(delay_seconds=0) + >>> trigger.evaluate(time=1) + + >>> trigger.evaluate(time=4) + + >>> trigger.evaluate(time=90) + + + >>> start_time = 5 + >>> between_time = 10 + >>> delay_seconds = 20 + >>> trigger = LiteralCondition(ConditionState.TRUE).trigger(delay_seconds=delay_seconds) + >>> trigger.evaluate(time=start_time) + + >>> trigger.evaluate(time=between_time) + + >>> trigger.evaluate(time=start_time + delay_seconds) + + >>> trigger.evaluate(time=between_time) + + + Args: + delay_seconds (float): Applies the trigger after the delay has passed since the inner condition first TRUE. Defaults to False. + persistant (bool, optional): Mixes the inner result with the trigger result using an AND operation. + + Returns: + ConditionTrigger: A resulting condition. + """ + return ConditionTrigger( + self, delay_seconds=delay_seconds, persistant=persistant + ) + + def expire( + self, time, expired_state=ConditionState.EXPIRED, relative: bool = False + ) -> "ExpireTrigger": + """This trigger evaluates to the expired state value after the given simulation time. + + >>> trigger = LiteralCondition(ConditionState.TRUE).expire(20) + >>> trigger.evaluate(time=10) + + >>> trigger.evaluate(time=30) + + + Args: + time (float): The simulation time when this trigger changes. + expired_state (ConditionState, optional): The condition state to use when the simulation is after the given time. Defaults to ConditionState.EXPIRED. + relative (bool, optional): If this trigger should resolve relative to the first evaluated time. + Returns: + ExpireTrigger: The resulting condition. + """ + return ExpireTrigger( + inner_condition=self, + time=time, + expired_state=expired_state, + relative=relative, + ) + + def __and__(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A AND B.""" + assert isinstance(other, Condition) + return self.conjunction(other) + + def __or__(self, other: "Condition") -> "CompoundCondition": + """Resolve conditions as A OR B.""" + assert isinstance(other, Condition) + return self.disjunction(other) + + def __neg__(self) -> "NegatedCondition": + """Negates this condition""" + return self.negation() + + +@dataclass(frozen=True) +class SubjectCondition(Condition): + """This condition assumes that there is a subject involved.""" + + def evaluate(self, **kwargs) -> ConditionState: + """Used to evaluate if a condition is met. + + Args: + actor_info: Information about the currently relevant actor. + Returns: + ConditionState: The evaluation result of the condition. + """ + raise NotImplementedError() + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_state + + +_abstract_conditions = (Condition, SubjectCondition) + + +@dataclass(frozen=True) +class LiteralCondition(Condition): + """This condition evaluates as a literal without considering evaluation parameters.""" + + literal: ConditionState + """The literal value of this condition.""" + + def evaluate(self, **kwargs) -> ConditionState: + return self.literal + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.none + + +@dataclass(frozen=True) +class TimeWindowCondition(Condition): + """This condition should be true in the given simulation time window.""" + + start: float + """The starting simulation time before which this condition becomes false.""" + end: float + """The ending simulation time as of which this condition becomes expired.""" + + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] + if self.start <= time < self.end or self.end == sys.maxsize: + return ConditionState.TRUE + elif time > self.end: + return ConditionState.EXPIRED + return ConditionState.BEFORE + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.time + + +@dataclass(frozen=True) +class DependeeActorCondition(Condition): + """This condition should be true if the given actor exists.""" + + actor_id: str + """The id of an actor in the simulation that needs to exist for this condition to be true.""" + + def evaluate(self, **kwargs) -> ConditionState: + actor_ids = kwargs[self.requires.name] + if self.actor_id in actor_ids: + return ConditionState.TRUE + return ConditionState.FALSE + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.actor_ids + + def __post_init__(self): + assert isinstance(self.actor_id, str) + + +@dataclass(frozen=True) +class NegatedCondition(Condition): + """This condition negates the inner condition to flip between TRUE and FALSE. + + Note\\: This erases temporal values EXPIRED and BEFORE. + """ + + inner_condition: Condition + """The inner condition to negate.""" + + def evaluate(self, **kwargs) -> ConditionState: + result = self.inner_condition.evaluate(**kwargs) + if ConditionState.TRUE in result: + return ConditionState.FALSE + return ConditionState.TRUE + + @property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires + + def __post_init__(self): + if self.inner_condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__class__.__name__}` cannot use the negation operation." + ) + + +@dataclass(frozen=True) +class ExpireTrigger(Condition): + """This condition allows for expiration after a given time.""" + + inner_condition: Condition + """The inner condition to delay.""" + + time: float + """The simulation time when this trigger becomes expired.""" + + expired_state: ConditionState = ConditionState.EXPIRED + """The state value this trigger should have when it expires.""" + + relative: bool = False + """If this should start relative to the first time evaluated.""" + + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] + if self.relative: + key = "met" + met_time = getattr(self, key, -1) + if met_time == -1: + object.__setattr__(self, key, time) + time = 0 + else: + time -= met_time + if time >= self.time: + return self.expired_state + return self.inner_condition.evaluate(**kwargs) + + @cached_property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires | ConditionRequires.time + + def __post_init__(self): + if self.inner_condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." + ) + + +@dataclass(frozen=True) +class ConditionTrigger(Condition): + """This condition is a trigger that assumes an untriggered constant state and then turns to the other state permanently + on the inner condition becoming TRUE. There is also an option to delay repsonse to the the inner condition by a number + of seconds. This will convey an EXPIRED value immediately because that state means the inner value will never be TRUE. + + This can be used to wait for some time after the inner condition has become TRUE to trigger. + Note that the original condition may no longer be true by the time delay has expired. + + This will never resolve TRUE on the first evaluate. + """ + + inner_condition: Condition + """The inner condition to delay.""" + + delay_seconds: float + """The number of seconds to delay for.""" + + untriggered_state: ConditionState = ConditionState.BEFORE + """The state before the inner trigger condition and delay is resolved.""" + + triggered_state: ConditionState = ConditionState.TRUE + """The state after the inner trigger condition and delay is resolved.""" + + persistant: bool = False + """If the inner condition state is used in conjuction with the triggered state. (inner_condition_state & triggered_state)""" + + def evaluate(self, **kwargs) -> ConditionState: + time = kwargs[ConditionRequires.time.name] + key = "met_time" + result = self.untriggered_state + met_time = getattr(self, key, -1) + if met_time == -1: + if self.inner_condition.evaluate(**kwargs): + object.__setattr__(self, key, time) + time = 0 + else: + time = -1 + else: + time -= met_time + if time >= self.delay_seconds: + result = self.triggered_state + if self.persistant: + result &= self.inner_condition.evaluate(**kwargs) + return result + + temporals = result & (ConditionState.EXPIRED) + if ConditionState.EXPIRED in temporals: + return ConditionState.EXPIRED + return self.untriggered_state + + @property + def requires(self) -> ConditionRequires: + return self.inner_condition.requires | ConditionRequires.time + + def __post_init__(self): + if self.inner_condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{self.inner_condition.__class__.__name__}` cannot be wrapped by a trigger." + ) + if self.delay_seconds < 0: + raise ValueError("Delay cannot be negative.") + + +@dataclass(frozen=True) +class OffRoadCondition(SubjectCondition): + """This condition is true if the subject is on road.""" + + def evaluate(self, **kwargs) -> ConditionState: + current_actor_road_status = kwargs[self.requires.name] + if ( + current_actor_road_status.road is None + and not current_actor_road_status.off_road + ): + return ConditionState.BEFORE + return ( + ConditionState.TRUE + if current_actor_road_status.off_road + else ConditionState.FALSE + ) + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_road_status + + +@dataclass(frozen=True) +class VehicleTypeCondition(SubjectCondition): + """This condition is true if the subject is of the given vehicle types.""" + + vehicle_type: str + + def evaluate(self, **kwargs) -> ConditionState: + current_actor_state = kwargs[self.requires.name] + return ( + ConditionState.TRUE + if current_actor_state.vehicle_config_type == self.vehicle_type + else ConditionState.FALSE + ) + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_state + + +@dataclass(frozen=True) +class VehicleSpeedCondition(SubjectCondition): + """This condition is true if the subject has a speed between low and high.""" + + low: float + """The lowest speed allowed.""" + + high: float + """The highest speed allowed.""" + + def evaluate(self, **kwargs) -> ConditionState: + vehicle_state = kwargs[self.requires.name] + return ( + ConditionState.TRUE + if self.low <= vehicle_state.speed <= self.high + else ConditionState.FALSE + ) + + @property + def requires(self) -> ConditionRequires: + return ConditionRequires.current_actor_state + + @classmethod + def loitering(cls: Type["VehicleSpeedCondition"], abs_error=0.01): + """Generates a speed condition which assumes that the subject is stationary.""" + return cls(low=abs_error, high=abs_error) + + +@dataclass(frozen=True) +class CompoundCondition(Condition): + """This compounds multiple conditions. + + The following cases are notable + CONJUNCTION (A AND B) + If both conditions evaluate TRUE the result is exclusively TRUE. + Else if either condition evaluates EXPIRED the result will be EXPIRED. + Else if either condition evaluates BEFORE the result will be BEFORE. + Else FALSE + DISJUNCTION (A OR B) + If either condition evaluates TRUE the result is exclusively TRUE. + Else if either condition evaluates BEFORE then the result will be BEFORE. + Else if both conditions evaluate EXPIRED then the result will be EXPIRED. + Else FALSE + IMPLICATION (A AND B or not A) + If the first condition evaluates *not* TRUE the result is exclusively TRUE. + Else if the first condition evaluates TRUE and the second condition evaluates TRUE the result is exclusively TRUE. + Else FALSE + """ + + first_condition: Condition + """The first condition.""" + + second_condition: Condition + """The second condition.""" + + operator: ConditionOperator + """The operator used to combine these conditions.""" + + def evaluate(self, **kwargs) -> ConditionState: + # Short circuits + first_eval = self.first_condition.evaluate(**kwargs) + if ( + self.operator == ConditionOperator.CONJUNCTION + and ConditionState.EXPIRED in first_eval + ): + return ConditionState.EXPIRED + elif ( + self.operator == ConditionOperator.DISJUNCTION + and ConditionState.TRUE in first_eval + ): + return ConditionState.TRUE + elif ( + self.operator == ConditionOperator.IMPLICATION + and ConditionState.TRUE not in first_eval + ): + return ConditionState.TRUE + + second_eval = self.second_condition.evaluate(**kwargs) + if ( + self.operator == ConditionOperator.IMPLICATION + and ConditionState.TRUE in first_eval + and ConditionState.TRUE in second_eval + ): + return ConditionState.TRUE + + elif self.operator == ConditionOperator.CONJUNCTION: + conjuction = first_eval & second_eval + if ConditionState.TRUE in conjuction: + return ConditionState.TRUE + + # To priority of temporal versions of FALSE + disjunction = first_eval | second_eval + if ConditionState.EXPIRED in disjunction: + return ConditionState.EXPIRED + + if ConditionState.BEFORE in disjunction: + return ConditionState.BEFORE + + elif self.operator == ConditionOperator.DISJUNCTION: + result = first_eval | second_eval + + if ConditionState.TRUE in result: + return ConditionState.TRUE + + if ConditionState.BEFORE in result: + return ConditionState.BEFORE + + if ConditionState.EXPIRED in first_eval & second_eval: + return ConditionState.EXPIRED + + return ConditionState.FALSE + + @cached_property + def requires(self) -> ConditionRequires: + return self.first_condition.requires | self.second_condition.requires + + def __post_init__(self): + for condition in (self.first_condition, self.second_condition): + if condition.__class__ in _abstract_conditions: + raise TypeError( + f"Abstract `{condition.__class__.__name__}` cannot use compound operations." + ) diff --git a/smarts/sstudio/types/constants.py b/smarts/sstudio/types/constants.py new file mode 100644 index 0000000000..78e58ce442 --- /dev/null +++ b/smarts/sstudio/types/constants.py @@ -0,0 +1,26 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import sys + +MAX = sys.maxsize +MISSING = sys.maxsize diff --git a/smarts/sstudio/types/dataset.py b/smarts/sstudio/types/dataset.py new file mode 100644 index 0000000000..91adef8c97 --- /dev/null +++ b/smarts/sstudio/types/dataset.py @@ -0,0 +1,71 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import math +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class TrafficHistoryDataset: + """Describes a dataset containing trajectories (time-stamped positions) + for a set of vehicles. Often these have been collected by third parties + from real-world observations, hence the name 'history'. When used + with a SMARTS scenario, traffic vehicles will move on the map according + to their trajectories as specified in the dataset. These can be mixed + with other types of traffic (such as would be specified by an object of + the Traffic type in this DSL). In order to use this efficiently, SMARTS + will pre-process ('import') the dataset when the scenario is built.""" + + name: str + """a unique name for the dataset""" + source_type: str + """the type of the dataset; supported values include: NGSIM, INTERACTION, Waymo""" + input_path: Optional[str] = None + """a relative or absolute path to the dataset; if omitted, dataset will not be imported""" + scenario_id: Optional[str] = None + """a unique ID for a Waymo scenario. For other datasets, this field will be None.""" + x_margin_px: float = 0.0 + """x offset of the map from the data (in pixels)""" + y_margin_px: float = 0.0 + """y offset of the map from the data (in pixels)""" + swap_xy: bool = False + """if True, the x and y axes the dataset coordinate system will be swapped""" + flip_y: bool = False + """if True, the dataset will be mirrored around the x-axis""" + filter_off_map: bool = False + """if True, then any vehicle whose coordinates on a time step fall outside of the map's bounding box will be removed for that time step""" + + map_lane_width: float = 3.7 + """This is used to figure out the map scale, which is map_lane_width / real_lane_width_m. (So use `real_lane_width_m` here for 1:1 scale - the default.) It's also used in SMARTS for detecting off-road, etc.""" + real_lane_width_m: float = 3.7 + """Average width in meters of the dataset's lanes in the real world. US highway lanes are about 12 feet (or ~3.7m, the default) wide.""" + speed_limit_mps: Optional[float] = None + """used by SMARTS for the initial speed of new agents being added to the scenario""" + + heading_inference_window: int = 2 + """When inferring headings from positions, a sliding window (moving average) of this size will be used to smooth inferred headings and reduce their dependency on any individual position changes. Defaults to 2 if not specified.""" + heading_inference_min_speed: float = 2.2 + """Speed threshold below which a vehicle's heading is assumed not to change. This is useful to prevent abnormal heading changes that may arise from noise in position estimates in a trajectory dataset dominating real position changes in situations where the real position changes are very small. Defaults to 2.2 m/s if not specified.""" + max_angular_velocity: Optional[float] = None + """When inferring headings from positions, each vehicle's angular velocity will be limited to be at most this amount (in rad/sec) to prevent lateral-coordinate noise in the dataset from causing near-instantaneous heading changes.""" + default_heading: float = 1.5 * math.pi + """A heading in radians to be used by default for vehicles if the headings are not present in the dataset and cannot be inferred from position changes (such as on the first time step).""" diff --git a/smarts/sstudio/types/distribution.py b/smarts/sstudio/types/distribution.py new file mode 100644 index 0000000000..7493967157 --- /dev/null +++ b/smarts/sstudio/types/distribution.py @@ -0,0 +1,78 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +import random +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Distribution: + """A gaussian distribution used for randomized parameters.""" + + mean: float + """The mean value of the gaussian distribution.""" + sigma: float + """The sigma value of the gaussian distribution.""" + + def sample(self): + """The next sample from the distribution.""" + return random.gauss(self.mean, self.sigma) + + +@dataclass +class UniformDistribution: + """A uniform distribution, return a random number N + such that a <= N <= b for a <= b and b <= N <= a for b < a. + """ + + a: float + b: float + + def __post_init__(self): + if self.b < self.a: + self.a, self.b = self.b, self.a + + def sample(self): + """Get the next sample.""" + return random.uniform(self.a, self.b) + + +@dataclass +class TruncatedDistribution: + """A truncated normal distribution, by default, location=0, scale=1""" + + a: float + b: float + loc: float = 0 + scale: float = 1 + + def __post_init__(self): + assert self.a != self.b + if self.b < self.a: + self.a, self.b = self.b, self.a + + def sample(self): + """Get the next sample""" + from scipy.stats import truncnorm + + return truncnorm.rvs(self.a, self.b, loc=self.loc, scale=self.scale) diff --git a/smarts/sstudio/types/entry_tactic.py b/smarts/sstudio/types/entry_tactic.py new file mode 100644 index 0000000000..33e19b449f --- /dev/null +++ b/smarts/sstudio/types/entry_tactic.py @@ -0,0 +1,82 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass +from typing import Optional, Tuple + +from smarts.core.condition_state import ConditionState +from smarts.sstudio.types.condition import ( + Condition, + ConditionRequires, + LiteralCondition, +) +from smarts.sstudio.types.zone import MapZone + + +@dataclass(frozen=True) +class EntryTactic: + """The tactic that the simulation should use to acquire a vehicle for an agent.""" + + start_time: float + + def __post_init__(self): + assert ( + getattr(self, "condition", None) is not None + ), "Abstract class, inheriting types must implement the `condition` field." + + +@dataclass(frozen=True) +class TrapEntryTactic(EntryTactic): + """An entry tactic that repurposes a pre-existing vehicle for an agent.""" + + wait_to_hijack_limit_s: float = 0 + """The amount of seconds a hijack will wait to get a vehicle before defaulting to a new vehicle""" + zone: Optional[MapZone] = None + """The zone of the hijack area""" + exclusion_prefixes: Tuple[str, ...] = tuple() + """The prefixes of vehicles to avoid hijacking""" + default_entry_speed: Optional[float] = None + """The speed that the vehicle starts at when the hijack limit expiry emits a new vehicle""" + condition: Condition = LiteralCondition(ConditionState.TRUE) + """A condition that is used to add additional exclusions.""" + + def __post_init__(self): + assert isinstance(self.condition, (Condition)) + assert not ( + self.condition.requires & ConditionRequires.any_current_actor_state + ), f"Trap entry tactic cannot use conditions that require any_vehicle_state." + + +@dataclass(frozen=True) +class IdEntryTactic(EntryTactic): + """An entry tactic which repurposes a pre-existing actor for an agent. Selects that actor by id.""" + + actor_id: str + """The id of the actor to take over.""" + + condition: Condition = LiteralCondition(ConditionState.TRUE) + """A condition that is used to add additional exclusions.""" + + def __post_init__(self): + assert isinstance(self.actor_id, str) + assert isinstance(self.condition, (Condition)) diff --git a/smarts/sstudio/types/map_spec.py b/smarts/sstudio/types/map_spec.py new file mode 100644 index 0000000000..bbfce962a7 --- /dev/null +++ b/smarts/sstudio/types/map_spec.py @@ -0,0 +1,59 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +# A MapBuilder should return an object derived from the RoadMap base class +# and a hash that uniquely identifies it (changes to the hash should signify +# that the map is different enough that map-related caches should be reloaded). +# +# This function should be re-callable (although caching is up to the implementation). +# The idea here is that anything in SMARTS that needs to use a RoadMap +# can call this builder to get or create one as necessary. +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple + +from smarts.core.default_map_builder import get_road_map +from smarts.core.road_map import RoadMap + +MapBuilder = Callable[[Any], Tuple[Optional[RoadMap], Optional[str]]] + + +@dataclass(frozen=True) +class MapSpec: + """A map specification that describes how to generate a roadmap.""" + + source: str + """A path or URL or name uniquely designating the map source.""" + lanepoint_spacing: float = 1.0 + """The default distance between pre-generated Lane Points (Waypoints).""" + default_lane_width: Optional[float] = None + """If specified, the default width (in meters) of lanes on this map.""" + shift_to_origin: bool = False + """If True, upon creation a map whose bounding-box does not intersect with + the origin point (0,0) will be shifted such that it does.""" + builder_fn: MapBuilder = get_road_map + """If specified, this should return an object derived from the RoadMap base class + and a hash that uniquely identifies it (changes to the hash should signify + that the map is different enough that map-related caches should be reloaded). + The parameter is this MapSpec object itself. + If not specified, this currently defaults to a function that creates + SUMO road networks (get_road_map()) in smarts.core.default_map_builder.""" diff --git a/smarts/sstudio/types/mission.py b/smarts/sstudio/types/mission.py new file mode 100644 index 0000000000..a0986ce42f --- /dev/null +++ b/smarts/sstudio/types/mission.py @@ -0,0 +1,145 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import sys +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +from smarts.sstudio.types.constants import MISSING +from smarts.sstudio.types.entry_tactic import EntryTactic +from smarts.sstudio.types.route import JunctionEdgeIDResolver, RandomRoute, Route + + +@dataclass(frozen=True) +class Via: + """A point on a road that an actor must pass through""" + + road_id: Union[str, JunctionEdgeIDResolver] + """The road this via is on""" + lane_index: int + """The lane this via sits on""" + lane_offset: int + """The offset along the lane where this via sits""" + required_speed: float + """The speed that a vehicle should travel through this via""" + hit_distance: float = -1 + """The distance at which this waypoint can be hit. Negative means half the lane radius.""" + + +@dataclass(frozen=True) +class Mission: + """The descriptor for an actor's mission.""" + + route: Union[RandomRoute, Route] + """The route for the actor to attempt to follow.""" + + via: Tuple[Via, ...] = () + """Points on an road that an actor must pass through""" + + start_time: float = MISSING + """The earliest simulation time that this mission starts but may start later in couple with + `entry_tactic`. + """ + + entry_tactic: Optional[EntryTactic] = None + """A specific tactic the mission should employ to start the mission.""" + + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + + +@dataclass(frozen=True) +class EndlessMission: + """The descriptor for an actor's mission that has no end.""" + + begin: Tuple[str, int, float] + """The (road, lane_index, offset) details of the start location for the route. + + road: + The starting road by name. + lane_index: + The lane index from the rightmost lane. + offset: + The offset in metres into the lane. Also acceptable\\: 'max', 'random' + """ + via: Tuple[Via, ...] = () + """Points on a road that an actor must pass through""" + start_time: float = MISSING + """The earliest simulation time that this mission starts""" + entry_tactic: Optional[EntryTactic] = None + """A specific tactic the mission should employ to start the mission""" + + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + + +@dataclass(frozen=True) +class LapMission: + """The descriptor for an actor's mission that defines mission that repeats + in a closed loop. + """ + + route: Route + """The route for the actor to attempt to follow""" + num_laps: int + """The amount of times to repeat the mission""" + via: Tuple[Via, ...] = () + """Points on a road that an actor must pass through""" + start_time: float = MISSING + """The earliest simulation time that this mission starts""" + entry_tactic: Optional[EntryTactic] = None + """A specific tactic the mission should employ to start the mission""" + + def __post_init__(self): + if self.start_time != sys.maxsize: + warnings.warn( + "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", + category=DeprecationWarning, + ) + + +@dataclass(frozen=True) +class GroupedLapMission: + """The descriptor for a group of actor missions that repeat in a closed loop.""" + + route: Route + """The route for the actors to attempt to follow""" + offset: int + """The offset of the "starting line" for the group""" + lanes: int + """The number of lanes the group occupies""" + actor_count: int + """The number of actors to be part of the group""" + num_laps: int + """The amount of times to repeat the mission""" + via: Tuple[Via, ...] = () + """Points on a road that an actor must pass through""" + entry_tactic: Optional[EntryTactic] = None + """A specific tactic the mission should employ to start the mission""" diff --git a/smarts/sstudio/types/route.py b/smarts/sstudio/types/route.py new file mode 100644 index 0000000000..fafb201e4d --- /dev/null +++ b/smarts/sstudio/types/route.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass, field +from typing import Any, Optional, Tuple + +from smarts.core import gen_id +from smarts.core.utils.file import pickle_hash_int +from smarts.sstudio.types.map_spec import MapSpec + + +@dataclass(frozen=True) +class JunctionEdgeIDResolver: + """A utility for resolving a junction connection edge""" + + start_edge_id: str + start_lane_index: int + end_edge_id: str + end_lane_index: int + + def to_edge(self, sumo_road_network) -> str: + """Queries the road network to see if there is a junction edge between the two + given edges. + """ + return sumo_road_network.get_edge_in_junction( + self.start_edge_id, + self.start_lane_index, + self.end_edge_id, + self.end_lane_index, + ) + + +@dataclass(frozen=True) +class Route: + """A route is represented by begin and end road IDs, with an optional list of + intermediary road IDs. When an intermediary is not specified the router will + decide what it should be. + """ + + ## road, lane index, offset + begin: Tuple[str, int, Any] + """The (road, lane_index, offset) details of the start location for the route. + + road: + The starting road by name. + lane_index: + The lane index from the rightmost lane. + offset: + The offset in metres into the lane. Also acceptable\\: "max", "random" + """ + ## road, lane index, offset + end: Tuple[str, int, Any] + """The (road, lane_index, offset) details of the end location for the route. + + road: + The starting road by name. + lane_index: + The lane index from the rightmost lane. + offset: + The offset in metres into the lane. Also acceptable\\: "max", "random" + """ + + # Roads we want to make sure this route includes + via: Tuple[str, ...] = field(default_factory=tuple) + """The ids of roads that must be included in the route between `begin` and `end`.""" + + map_spec: Optional[MapSpec] = None + """All routes are relative to a road map. If not specified here, + the default map_spec for the scenario is used.""" + + @property + def id(self) -> str: + """The unique id of this route.""" + return "{}-{}-{}".format( + "_".join(map(str, self.begin)), + "_".join(map(str, self.end)), + str(hash(self))[:6], + ) + + @property + def roads(self): + """All roads that are used within this route.""" + return (self.begin[0],) + self.via + (self.end[0],) + + def __hash__(self): + return pickle_hash_int(self) + + def __eq__(self, other): + return self.__class__ == other.__class__ and hash(self) == hash(other) + + +@dataclass(frozen=True) +class RandomRoute: + """An alternative to types.Route which specifies to sstudio to generate a random + route. + """ + + id: str = field(default_factory=lambda: f"random-route-{gen_id()}") + + map_spec: Optional[MapSpec] = None + """All routes are relative to a road map. If not specified here, + the default map_spec for the scenario is used.""" + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other): + return self.__class__ == other.__class__ and hash(self) == hash(other) diff --git a/smarts/sstudio/types/scenario.py b/smarts/sstudio/types/scenario.py new file mode 100644 index 0000000000..e49932eef4 --- /dev/null +++ b/smarts/sstudio/types/scenario.py @@ -0,0 +1,92 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +from dataclasses import dataclass +from typing import Dict, Optional, Sequence, Tuple, Union + +from smarts.core.colors import Colors +from smarts.sstudio.types.actor.social_agent_actor import SocialAgentActor +from smarts.sstudio.types.bubble import Bubble +from smarts.sstudio.types.dataset import TrafficHistoryDataset +from smarts.sstudio.types.map_spec import MapSpec +from smarts.sstudio.types.mission import EndlessMission, Mission +from smarts.sstudio.types.traffic import Traffic +from smarts.sstudio.types.zone import RoadSurfacePatch + + +@dataclass(frozen=True) +class ScenarioMetadata: + """Scenario data that does not have influence on simulation.""" + + actor_of_interest_re_filter: str + """Vehicles with names that match this pattern are vehicles of interest.""" + actor_of_interest_color: Colors + """The color that the vehicles of interest should have.""" + + +@dataclass(frozen=True) +class Scenario: + """The sstudio scenario representation.""" + + map_spec: Optional[MapSpec] = None + """Specifies the road map.""" + traffic: Optional[Dict[str, Traffic]] = None + """Background traffic vehicle specification.""" + ego_missions: Optional[Sequence[Union[Mission, EndlessMission]]] = None + """Ego agent missions.""" + social_agent_missions: Optional[ + Dict[str, Tuple[Sequence[SocialAgentActor], Sequence[Mission]]] + ] = None + """ + Actors must have unique names regardless of which group they are assigned to. + Every dictionary item ``{group: (actors, missions)}`` gets selected from simultaneously. + If actors > 1 and missions = 0 or actors = 1 and missions > 0, we cycle + through them every episode. Otherwise actors must be the same length as + missions. + """ + bubbles: Optional[Sequence[Bubble]] = None + """Capture bubbles for focused social agent simulation.""" + friction_maps: Optional[Sequence[RoadSurfacePatch]] = None + """Friction coefficient of patches of road surface.""" + traffic_histories: Optional[Sequence[Union[TrafficHistoryDataset, str]]] = None + """Traffic vehicles trajectory dataset to be replayed.""" + scenario_metadata: Optional[ScenarioMetadata] = None + """"Scenario data that does not have influence on simulation.""" + + def __post_init__(self): + def _get_name(item): + return item.name + + if self.social_agent_missions is not None: + groups = [k for k in self.social_agent_missions] + for group, (actors, _) in self.social_agent_missions.items(): + for o_group in groups: + if group == o_group: + continue + if intersection := set.intersection( + set(map(_get_name, actors)), + map(_get_name, self.social_agent_missions[o_group][0]), + ): + raise ValueError( + f"Social agent mission groups `{group}`|`{o_group}` have overlapping actors {intersection}" + ) diff --git a/smarts/sstudio/types/traffic.py b/smarts/sstudio/types/traffic.py new file mode 100644 index 0000000000..6896046ed0 --- /dev/null +++ b/smarts/sstudio/types/traffic.py @@ -0,0 +1,129 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from dataclasses import dataclass, field, replace +from typing import Dict, Optional, Sequence, Union + +from smarts.core.utils.file import pickle_hash_int +from smarts.sstudio.types.actor.traffic_actor import TrafficActor +from smarts.sstudio.types.route import RandomRoute, Route + + +@dataclass(frozen=True) +class Flow: + """A route with an actor type emitted at a given rate.""" + + route: Union[RandomRoute, Route] + """The route for the actor to attempt to follow.""" + rate: float + """Vehicles per hour.""" + begin: float = 0 + """Start time in seconds.""" + # XXX: Defaults to 1 hour of traffic. We may want to change this to be "continual + # traffic", effectively an infinite end. + end: float = 1 * 60 * 60 + """End time in seconds.""" + actors: Dict[TrafficActor, float] = field(default_factory=dict) + """An actor to weight mapping associated as\\: { actor: weight } + + :param actor: The traffic actors that are provided. + :param weight: The chance of this actor appearing as a ratio over total weight. + """ + randomly_spaced: bool = False + """Determines if the flow should have randomly spaced traffic. Defaults to `False`.""" + repeat_route: bool = False + """If True, vehicles that finish their route will be restarted at the beginning. Defaults to `False`.""" + + @property + def id(self) -> str: + """The unique id of this flow.""" + return "{}-{}".format( + self.route.id, + str(hash(self))[:6], + ) + + def __hash__(self): + # Custom hash since self.actors is not hashable, here we first convert to a + # frozenset. + return pickle_hash_int((self.route, self.rate, frozenset(self.actors.items()))) + + def __eq__(self, other): + return self.__class__ == other.__class__ and hash(self) == hash(other) + + +@dataclass(frozen=True) +class Trip: + """A route with a single actor type with name and unique id.""" + + vehicle_name: str + """The name of the vehicle. It must be unique. """ + route: Union[RandomRoute, Route] + """The route for the actor to attempt to follow.""" + vehicle_type: str = "passenger" + """The type of the vehicle""" + depart: float = 0 + """Start time in seconds.""" + actor: Optional[TrafficActor] = field(default=None) + """The traffic actor model (usually vehicle) that will be used for the trip.""" + + def __post_init__(self): + object.__setattr__( + self, + "actor", + ( + replace( + self.actor, name=self.vehicle_name, vehicle_type=self.vehicle_type + ) + if self.actor is not None + else TrafficActor( + name=self.vehicle_name, vehicle_type=self.vehicle_type + ) + ), + ) + + @property + def id(self) -> str: + """The unique id of this trip.""" + return self.vehicle_name + + def __hash__(self): + # Custom hash since self.actors is not hashable, here we first convert to a + # frozenset. + return pickle_hash_int((self.route, self.actor)) + + def __eq__(self, other): + return self.__class__ == other.__class__ and hash(self) == hash(other) + + +@dataclass(frozen=True) +class Traffic: + """The descriptor for traffic.""" + + flows: Sequence[Flow] + """Flows are used to define a steady supply of vehicles.""" + # TODO: consider moving TrafficHistory stuff in here (and rename to Trajectory) + # TODO: - treat history points like Vias (no guarantee on history timesteps anyway) + trips: Optional[Sequence[Trip]] = None + """Trips are used to define a series of single vehicle trip.""" + engine: str = "SUMO" + """Traffic-generation engine to use. Supported values include "SUMO" and "SMARTS". "SUMO" requires using a SumoRoadNetwork for the RoadMap. + """ diff --git a/smarts/sstudio/types/traffic_model.py b/smarts/sstudio/types/traffic_model.py new file mode 100644 index 0000000000..00fecf4e97 --- /dev/null +++ b/smarts/sstudio/types/traffic_model.py @@ -0,0 +1,169 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import collections.abc as collections_abc +from enum import IntEnum +from typing import Callable + + +class _SUMO_PARAMS_MODE(IntEnum): + TITLE_CASE = 0 + KEEP_SNAKE_CASE = 1 + + +class _SumoParams(collections_abc.Mapping): + """For some Sumo params (e.g. LaneChangingModel) the arguments are in title case + with a given prefix. Subclassing this class allows for an automatic way to map + between PEP8-compatible naming and Sumo's. + """ + + def __init__( + self, prefix, whitelist=[], mode=_SUMO_PARAMS_MODE.TITLE_CASE, **kwargs + ): + def snake_to_title(word): + return "".join(x.capitalize() or "_" for x in word.split("_")) + + def keep_snake_case(word: str): + w = word[0].upper() + word[1:] + return "".join(x or "_" for x in w.split("_")) + + func: Callable[[str], str] = snake_to_title + if mode == _SUMO_PARAMS_MODE.TITLE_CASE: + pass + elif mode == _SUMO_PARAMS_MODE.KEEP_SNAKE_CASE: + func = keep_snake_case + + # XXX: On rare occasions sumo doesn't respect their own conventions + # (e.x. junction model's impatience). + self._params = {key: kwargs.pop(key) for key in whitelist if key in kwargs} + + for key, value in kwargs.items(): + self._params[f"{prefix}{func(key)}"] = value + + def __iter__(self): + return iter(self._params) + + def __getitem__(self, key): + return self._params[key] + + def __len__(self): + return len(self._params) + + def __hash__(self): + return hash(frozenset(self._params.items())) + + def __eq__(self, other): + return self.__class__ == other.__class__ and hash(self) == hash(other) + + +class LaneChangingModel(_SumoParams): + """Models how the actor acts with respect to lane changes.""" + + # For SUMO-specific attributes, see: + # https://sumo.dlr.de/docs/Definition_of_Vehicles%2C_Vehicle_Types%2C_and_Routes.html#lane-changing_models + + def __init__(self, **kwargs): + super().__init__("lc", whitelist=["minGapLat"], **kwargs) + + +class JunctionModel(_SumoParams): + """Models how the actor acts with respect to waiting at junctions.""" + + def __init__(self, **kwargs): + super().__init__("jm", whitelist=["impatience"], **kwargs) + + +class SmartsLaneChangingModel(LaneChangingModel): + """Implements the simple lane-changing model built-into SMARTS. + + Args: + cutin_prob (float, optional): Float value [0, 1] that + determines the probabilty this vehicle will "arbitrarily" cut in + front of an adjacent agent vehicle when it has a chance, even if + there would otherwise be no reason to change lanes at that point. + Higher values risk a situation where this vehicle ends up in a lane + where it cannot maintain its planned route. If that happens, this + vehicle will perform whatever its default behavior is when it + completes its route. Defaults to 0.0. + assertive (float, optional): Willingness to accept lower front and rear + gaps in the target lane. The required gap is divided by this value. + Attempts to match the semantics of the attribute in SUMO's default + lane-changing model, see: ``https://sumo.dlr.de/docs/Definition_of_Vehicles%2C_Vehicle_Types%2C_and_Routes.html#lane-changing_models``. + Range: positive reals. Defaults to 1.0. + dogmatic (bool, optional): If True, will cutin when a suitable + opportunity presents itself based on the above parameters, even if + it means the risk of not not completing the assigned route; + otherwise, will forego the chance. Defaults to True. + hold_period (float, optional): The minimum amount of time (seconds) to + remain in the agent's lane after cutting into it (including the + time it takes within the lane to complete the maneuver). Must be + non-negative. Defaults to 3.0. + slow_down_after (float, optional): Target speed during the hold_period + will be scaled by this value. Must be non-negative. Defaults to 1.0. + multi_lane_cutin (bool, optional): If True, this vehicle will consider + changing across multiple lanes at once in order to cutin upon an + agent vehicle when there's an opportunity. Defaults to False. + """ + + def __init__( + self, + cutin_prob: float = 0.0, + assertive: float = 1.0, + dogmatic: bool = True, + hold_period: float = 3.0, + slow_down_after: float = 1.0, + multi_lane_cutin: bool = False, + ): + super().__init__( + cutin_prob=cutin_prob, + assertive=assertive, + dogmatic=dogmatic, + hold_period=hold_period, + slow_down_after=slow_down_after, + multi_lane_cutin=multi_lane_cutin, + ) + + +class SmartsJunctionModel(JunctionModel): + """Implements the simple junction model built-into SMARTS. + + Args: + yield_to_agents (str, optional): Defaults to "normal". 3 options are + available, namely: (1) "always" - Traffic actors will yield to Ego + and Social agents within junctions. (2) "never" - Traffic actors + will never yield to Ego or Social agents within junctions. + (3) "normal" - Traffic actors will attempt to honor normal + right-of-way conventions, only yielding when an agent has the + right-of-way. Examples of such conventions include (a) vehicles + going straight have the right-of-way over turning vehicles; + (b) vehicles on roads with more lanes have the right-of-way + relative to vehicles on intersecting roads with less lanes; + (c) all other things being equal, the vehicle to the right + in a counter-clockwise sense has the right-of-way. + wait_to_restart (float, optional): The amount of time in seconds + after stopping at a signal or stop sign before this vehicle + will start to go again. Defaults to 0.0. + """ + + def __init__(self, yield_to_agents: str = "normal", wait_to_restart: float = 0.0): + super().__init__( + yield_to_agents=yield_to_agents, wait_to_restart=wait_to_restart + ) diff --git a/smarts/sstudio/types/zone.py b/smarts/sstudio/types/zone.py new file mode 100644 index 0000000000..bb83d1e868 --- /dev/null +++ b/smarts/sstudio/types/zone.py @@ -0,0 +1,274 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import logging +import random +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +from shapely.affinity import rotate as shapely_rotate +from shapely.affinity import translate as shapely_translate +from shapely.geometry import ( + GeometryCollection, + LineString, + MultiPolygon, + Point, + Polygon, + box, +) +from shapely.ops import split, unary_union + +from smarts.core.coordinates import RefLinePoint +from smarts.core.road_map import RoadMap +from smarts.core.utils.math import rotate_cw_around_point + + +@dataclass(frozen=True) +class Zone: + """The base for a descriptor that defines a capture area.""" + + def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: + """Generates the geometry from this zone.""" + raise NotImplementedError + + +@dataclass(frozen=True) +class MapZone(Zone): + """A descriptor that defines a capture area.""" + + start: Tuple[str, int, float] + """The (road_id, lane_index, offset) details of the starting location. + + road_id: + The starting road by name. + lane_index: + The lane index from the rightmost lane. + offset: + The offset in metres into the lane. Also acceptable\\: 'max', 'random' + """ + length: float + """The length of the geometry along the center of the lane. Also acceptable\\: 'max'""" + n_lanes: int = 2 + """The number of lanes from right to left that this zone covers.""" + + def to_geometry(self, road_map: Optional[RoadMap]) -> Polygon: + """Generates a map zone over a stretch of the given lanes.""" + + assert ( + road_map is not None + ), f"{self.__class__.__name__} requires a road map to resolve geometry." + + def resolve_offset(offset, geometry_length, lane_length): + if offset == "base": + return 0 + # push off of end of lane + elif offset == "max": + return lane_length - geometry_length + elif offset == "random": + return random.uniform(0, lane_length - geometry_length) + else: + return float(offset) + + def pick_remaining_shape_after_split(geometry_collection, expected_point): + lane_shape = geometry_collection + if not isinstance(lane_shape, GeometryCollection): + return lane_shape + + # For simplicity, we only deal w/ the == 1 or 2 case + if len(lane_shape.geoms) not in {1, 2}: + return None + + if len(lane_shape.geoms) == 1: + return lane_shape.geoms[0] + + # We assume that there are only two split shapes to choose from + keep_index = 0 + if lane_shape.geoms[1].minimum_rotated_rectangle.contains(expected_point): + # 0 is the discard piece, keep the other + keep_index = 1 + + lane_shape = lane_shape.geoms[keep_index] + + return lane_shape + + def split_lane_shape_at_offset( + lane_shape: Polygon, lane: RoadMap.Lane, offset: float + ): + # XXX: generalize to n-dim + width_2, _ = lane.width_at_offset(offset) + point = np.array(lane.from_lane_coord(RefLinePoint(offset)))[:2] + lane_vec = lane.vector_at_offset(offset)[:2] + + perp_vec_right = rotate_cw_around_point(lane_vec, np.pi / 2, origin=(0, 0)) + perp_vec_right = ( + perp_vec_right / max(np.linalg.norm(perp_vec_right), 1e-3) * width_2 + + point + ) + + perp_vec_left = rotate_cw_around_point(lane_vec, -np.pi / 2, origin=(0, 0)) + perp_vec_left = ( + perp_vec_left / max(np.linalg.norm(perp_vec_left), 1e-3) * width_2 + + point + ) + + split_line = LineString([perp_vec_left, perp_vec_right]) + return split(lane_shape, split_line) + + lane_shapes = [] + road_id, lane_idx, offset = self.start + road = road_map.road_by_id(road_id) + buffer_from_ends = 1e-6 + for lane_idx in range(lane_idx, lane_idx + self.n_lanes): + lane = road.lane_at_index(lane_idx) + lane_length = lane.length + geom_length = self.length + + if geom_length > lane_length: + logging.debug( + f"Geometry is too long={geom_length} with offset={offset} for " + f"lane={lane.lane_id}, using length={lane_length} instead" + ) + geom_length = lane_length + + assert geom_length > 0 # Geom length is negative + + lane_offset = resolve_offset(offset, geom_length, lane_length) + lane_offset += buffer_from_ends + width, _ = lane.width_at_offset(lane_offset) # TODO + lane_shape = lane.shape(0.3, width) # TODO + + geom_length = max(geom_length - buffer_from_ends, buffer_from_ends) + lane_length = max(lane_length - buffer_from_ends, buffer_from_ends) + + min_cut = min(lane_offset, lane_length) + # Second cut takes into account shortening of geometry by `min_cut`. + max_cut = min(min_cut + geom_length, lane_length) + + midpoint = Point( + *lane.from_lane_coord(RefLinePoint(s=lane_offset + geom_length * 0.5)) + ) + + lane_shape = split_lane_shape_at_offset(lane_shape, lane, min_cut) + lane_shape = pick_remaining_shape_after_split(lane_shape, midpoint) + if lane_shape is None: + continue + + lane_shape = split_lane_shape_at_offset( + lane_shape, + lane, + max_cut, + ) + lane_shape = pick_remaining_shape_after_split(lane_shape, midpoint) + if lane_shape is None: + continue + + lane_shapes.append(lane_shape) + + geom = unary_union(MultiPolygon(lane_shapes)) + return geom + + +@dataclass(frozen=True) +class PositionalZone(Zone): + """A descriptor that defines a capture area at a specific XY location.""" + + # center point + pos: Tuple[float, float] + """A (x,y) position of the zone in the scenario.""" + size: Tuple[float, float] + """The (length, width) dimensions of the zone.""" + rotation: Optional[float] = None + """The heading direction of the bubble. (radians, clock-wise rotation)""" + + def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: + """Generates a box zone at the given position.""" + w, h = self.size + x, y = self.pos[:2] + p0 = (-w / 2, -h / 2) # min + p1 = (w / 2, h / 2) # max + poly = Polygon([p0, (p0[0], p1[1]), p1, (p1[0], p0[1])]) + if self.rotation is not None: + poly = shapely_rotate(poly, self.rotation, use_radians=True) + return shapely_translate(poly, xoff=x, yoff=y) + + +@dataclass(frozen=True) +class ConfigurableZone(Zone): + """A descriptor for a zone with user-defined geometry.""" + + ext_coordinates: List[Tuple[float, float]] + """external coordinates of the polygon + < 2 points provided: error + = 2 points provided: generates a box using these two points as diagonal + > 2 points provided: generates a polygon according to the coordinates""" + rotation: Optional[float] = None + """The heading direction of the bubble(radians, clock-wise rotation)""" + + def __post_init__(self): + if ( + not self.ext_coordinates + or len(self.ext_coordinates) < 2 + or not isinstance(self.ext_coordinates[0], tuple) + ): + raise ValueError( + "Two points or more are needed to create a polygon. (less than two points are provided)" + ) + + x_set = set(point[0] for point in self.ext_coordinates) + y_set = set(point[1] for point in self.ext_coordinates) + if len(x_set) == 1 or len(y_set) == 1: + raise ValueError( + "Parallel line cannot form a polygon. (points provided form a parallel line)" + ) + + def to_geometry(self, road_map: Optional[RoadMap] = None) -> Polygon: + """Generate a polygon according to given coordinates""" + poly = None + if ( + len(self.ext_coordinates) == 2 + ): # if user only specified two points, create a box + x_min = min(self.ext_coordinates[0][0], self.ext_coordinates[1][0]) + x_max = max(self.ext_coordinates[0][0], self.ext_coordinates[1][0]) + y_min = min(self.ext_coordinates[0][1], self.ext_coordinates[1][1]) + y_max = max(self.ext_coordinates[0][1], self.ext_coordinates[1][1]) + poly = box(x_min, y_min, x_max, y_max) + + else: # else create a polygon according to the coordinates + poly = Polygon(self.ext_coordinates) + + if self.rotation is not None: + poly = shapely_rotate(poly, self.rotation, use_radians=True) + return poly + + +@dataclass(frozen=True) +class RoadSurfacePatch: + """A descriptor that defines a patch of road surface with a different friction coefficient.""" + + zone: Zone + """The zone which to capture vehicles.""" + begin_time: int + """The start time in seconds of when this surface is active.""" + end_time: int + """The end time in seconds of when this surface is active.""" + friction_coefficient: float + """The surface friction coefficient.""" From 85d068f0052f6c5a950b0789b7081c123f5e2d8a Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Tue, 16 May 2023 12:43:40 +0000 Subject: [PATCH 57/59] Update scenarios. --- smarts/core/plan.py | 14 ++++++++++++-- smarts/core/scenario.py | 6 +++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/smarts/core/plan.py b/smarts/core/plan.py index 1f70de8b90..9718017eda 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -23,6 +23,8 @@ import math import random +import sys +import warnings from dataclasses import dataclass, field from typing import List, Optional, Tuple @@ -33,6 +35,8 @@ from smarts.core.utils.math import min_angles_difference_signed, vec_to_radians from smarts.sstudio.types import EntryTactic, TrapEntryTactic +MISSING = sys.maxsize + class PlanningError(Exception): """Raised in cases when map related planning fails.""" @@ -166,7 +170,7 @@ def _drove_off_map(self, veh_pos: Point, veh_heading: float) -> bool: def default_entry_tactic(default_entry_speed: Optional[float] = None) -> EntryTactic: """The default tactic the simulation will use to acquire an actor for an agent.""" return TrapEntryTactic( - start_time=0, + start_time=MISSING, wait_to_hijack_limit_s=0, exclusion_prefixes=tuple(), zone=None, @@ -209,7 +213,7 @@ class Mission: # An optional list of road IDs between the start and end goal that we want to # ensure the mission includes route_vias: Tuple[str, ...] = field(default_factory=tuple) - start_time: float = 0.1 + start_time: float = MISSING entry_tactic: Optional[EntryTactic] = None via: Tuple[Via, ...] = () # if specified, will use vehicle_spec to build the vehicle (for histories) @@ -258,6 +262,12 @@ def random_endless_mission( target_pose = n_lane.center_pose_at_point(coord) return Mission.endless_mission(start_pose=target_pose) + def __post_init__(self): + if self.entry_tactic is not None and self.entry_tactic.start_time != MISSING: + object.__setattr__(self, "start_time", self.entry_tactic.start_time) + elif self.start_time == MISSING: + object.__setattr__(self, "start_time", 0.1) + @dataclass(frozen=True) class LapMission(Mission): diff --git a/smarts/core/scenario.py b/smarts/core/scenario.py index d37fdecb4a..faf77c99d5 100644 --- a/smarts/core/scenario.py +++ b/smarts/core/scenario.py @@ -861,10 +861,10 @@ def to_scenario_via( @staticmethod def _extract_mission_start_time(mission, entry_tactic: Optional[EntryTactic]): return ( - entry_tactic.start_time + mission.start_time + if mission.start_time != sstudio_types.MISSING + else entry_tactic.start_time if entry_tactic - else mission.start_time - if mission.start_time < sstudio_types.MISSING else 0 ) From 7a07ab483447405b6b62e949d1682679cbaaa33a Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Tue, 16 May 2023 13:01:32 +0000 Subject: [PATCH 58/59] Fix type test. --- smarts/sstudio/types/actor/social_agent_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smarts/sstudio/types/actor/social_agent_actor.py b/smarts/sstudio/types/actor/social_agent_actor.py index 6096a847b8..23e8d665ae 100644 --- a/smarts/sstudio/types/actor/social_agent_actor.py +++ b/smarts/sstudio/types/actor/social_agent_actor.py @@ -63,5 +63,5 @@ class BoidAgentActor(SocialAgentActor): # The max number of vehicles that this agent will control at a time. This value is # honored when using a bubble for boid dynamic assignment. - capacity: BubbleLimits = None + capacity: Optional[BubbleLimits] = None """The capacity of the boid agent to take over vehicles.""" From 1ece0defd8399999df5f7d3675e6f3e565b3f928 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Tue, 16 May 2023 13:09:30 +0000 Subject: [PATCH 59/59] Fix data formatting test. --- envision/tests/test_data_formatter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/envision/tests/test_data_formatter.py b/envision/tests/test_data_formatter.py index 6952cad2b2..064b0ca050 100644 --- a/envision/tests/test_data_formatter.py +++ b/envision/tests/test_data_formatter.py @@ -300,10 +300,10 @@ def sim_data(): [], [4], { - 0: "car-west_0_0-east_0_max-784511-552438-0-0.0", + 0: "car-west_0_0-east_0_max--41457-668134-0-0.0", 1: None, - 2: "car-west_1_0-east_1_max--85270-291315-1-0.0", - 3: "car-west_2_0-east_2_max--63247--53682-2-0.0", + 2: "car-west_1_0-east_1_max--31231--18481-1-0.0", + 3: "car-west_2_0-east_2_max-674625--72317-2-0.0", 4: "AGENT_1", }, [],