From 13e51dfe60eb0a25e4c2b421aaaf43d1ee063472 Mon Sep 17 00:00:00 2001 From: alwaysintreble Date: Tue, 8 Oct 2024 17:45:40 -0500 Subject: [PATCH] move getters to access the RegionManager and add variable typing --- BaseClasses.py | 54 +++++++++++++++++++++++++++++++-------------- worlds/AutoWorld.py | 6 ++--- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/BaseClasses.py b/BaseClasses.py index a345ccb3c1a7..f24c3ebcd3e8 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -55,10 +55,15 @@ class HasNameAndPlayer(Protocol): player: int +_T_Reg = typing.TypeVar("_T_Reg", bound="Region") +_T_Ent = typing.TypeVar("_T_Ent", bound="Entrance") +_T_Loc = typing.TypeVar("_T_Loc", bound="Location") + + class RegionManager(typing.Generic[_T_Reg, _T_Ent, _T_Loc]): - region_cache: Union[Dict[int, Dict[str, Region]], Dict[str, Region]] - entrance_cache: Union[Dict[int, Dict[str, Entrance]], Dict[str, Entrance]] - location_cache: Union[Dict[int, Dict[str, Location]], Dict[str, Location]] + region_cache: Dict[str, Region] + entrance_cache: Dict[str, Entrance] + location_cache: Dict[str, Location] multiworld: "MultiWorld" def __init__(self, multiworld: "Multiworld" = None): @@ -94,11 +99,28 @@ def extend(self, regions: Iterable[Region]): region_cache[region.name] = region def __iter__(self) -> Iterator[Region]: - for regions in self.region_cache.values(): - yield from regions.values() + yield from self.region_cache.values() def __len__(self): - return sum(len(regions) for regions in self.region_cache.values()) + return len(self.region_cache.values()) + + def get_regions(self) -> typing.Iterable[_T_Reg]: + return self.region_cache.values() + + def get_region(self, name: str) -> _T_Reg: + return self.region_cache[name] + + def get_locations(self) -> typing.Iterable[_T_Loc]: + return self.location_cache.values() + + def get_location(self, name: str) -> _T_Loc: + return self.location_cache[name] + + def get_entrances(self) -> typing.Iterable[_T_Ent]: + return self.entrance_cache.values() + + def get_entrance(self, name: str) -> _T_Ent: + return self.entrance_cache[name] class MultiWorld(): @@ -426,18 +448,18 @@ def world_name_lookup(self): def get_regions(self, player: Optional[int] = None) -> Collection[Region]: if player is not None: - return self.worlds[player].regions.region_cache.values() - return Utils.RepeatableChain(tuple(self.worlds[player].regions.region_cache.values() + return self.worlds[player].regions.get_regions() + return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_regions() for player in self.get_all_ids())) def get_region(self, region_name: str, player: int) -> Region: - return self.worlds[player].get_region(region_name) + return self.worlds[player].regions.get_region(region_name) def get_entrance(self, entrance_name: str, player: int) -> Entrance: - return self.worlds[player].get_entrance(entrance_name) + return self.worlds[player].regions.get_entrance(entrance_name) def get_location(self, location_name: str, player: int) -> Location: - return self.worlds[player].get_location(location_name) + return self.worlds[player].regions.get_location(location_name) def get_all_state(self, use_cache: bool) -> CollectionState: cached = getattr(self, "_all_state", None) @@ -500,8 +522,8 @@ def push_item(self, location: Location, item: Item, collect: bool = True): def get_entrances(self, player: Optional[int] = None) -> Iterable[Entrance]: if player is not None: - return self.worlds[player].regions.entrance_cache.values() - return Utils.RepeatableChain(tuple(self.worlds[player].regions.entrance_cache.values() + return self.worlds[player].regions.get_entrances() + return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_entrances() for player in self.get_all_ids())) def register_indirect_condition(self, region: Region, entrance: Entrance): @@ -511,8 +533,8 @@ def register_indirect_condition(self, region: Region, entrance: Entrance): def get_locations(self, player: Optional[int] = None) -> Iterable[Location]: if player is not None: - return self.worlds[player].regions.location_cache.values() - return Utils.RepeatableChain(tuple(self.worlds[player].regions.location_cache.values() + return self.worlds[player].regions.get_locations() + return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_locations() for player in self.get_all_ids())) def get_unfilled_locations(self, player: Optional[int] = None) -> List[Location]: @@ -1015,7 +1037,7 @@ class LocationRegister(Register): def __delitem__(self, index: int) -> None: location: Location = self._list.__getitem__(index) self._list.__delitem__(index) - del(self.region_manager.location_cache[location.player][location.name]) + del(self.region_manager.location_cache[location.name]) def insert(self, index: int, value: Location) -> None: assert value.name not in self.region_manager.location_cache, \ diff --git a/worlds/AutoWorld.py b/worlds/AutoWorld.py index 88550c407ca5..709aba566e12 100644 --- a/worlds/AutoWorld.py +++ b/worlds/AutoWorld.py @@ -532,13 +532,13 @@ def create_filler(self) -> "Item": # convenience methods def get_location(self, location_name: str) -> "Location": - return self.regions.location_cache[location_name] + return self.regions.get_location(location_name) def get_entrance(self, entrance_name: str) -> "Entrance": - return self.regions.entrance_cache[entrance_name] + return self.regions.get_entrance(entrance_name) def get_region(self, region_name: str) -> "Region": - return self.regions.region_cache[region_name] + return self.regions.get_region(region_name) @property def player_name(self) -> str: