Skip to content

Commit

Permalink
move getters to access the RegionManager and add variable typing
Browse files Browse the repository at this point in the history
  • Loading branch information
alwaysintreble committed Oct 8, 2024
1 parent 9379c78 commit 13e51df
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
54 changes: 38 additions & 16 deletions BaseClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ class HasNameAndPlayer(Protocol):
player: int


_T_Reg = typing.TypeVar("_T_Reg", bound="Region")
_T_Ent = typing.TypeVar("_T_Ent", bound="Entrance")
_T_Loc = typing.TypeVar("_T_Loc", bound="Location")


class RegionManager(typing.Generic[_T_Reg, _T_Ent, _T_Loc]):
region_cache: Union[Dict[int, Dict[str, Region]], Dict[str, Region]]
entrance_cache: Union[Dict[int, Dict[str, Entrance]], Dict[str, Entrance]]
location_cache: Union[Dict[int, Dict[str, Location]], Dict[str, Location]]
region_cache: Dict[str, Region]
entrance_cache: Dict[str, Entrance]
location_cache: Dict[str, Location]
multiworld: "MultiWorld"

def __init__(self, multiworld: "Multiworld" = None):
Expand Down Expand Up @@ -94,11 +99,28 @@ def extend(self, regions: Iterable[Region]):
region_cache[region.name] = region

def __iter__(self) -> Iterator[Region]:
for regions in self.region_cache.values():
yield from regions.values()
yield from self.region_cache.values()

def __len__(self):
return sum(len(regions) for regions in self.region_cache.values())
return len(self.region_cache.values())

def get_regions(self) -> typing.Iterable[_T_Reg]:
return self.region_cache.values()

def get_region(self, name: str) -> _T_Reg:
return self.region_cache[name]

def get_locations(self) -> typing.Iterable[_T_Loc]:
return self.location_cache.values()

def get_location(self, name: str) -> _T_Loc:
return self.location_cache[name]

def get_entrances(self) -> typing.Iterable[_T_Ent]:
return self.entrance_cache.values()

def get_entrance(self, name: str) -> _T_Ent:
return self.entrance_cache[name]


class MultiWorld():
Expand Down Expand Up @@ -426,18 +448,18 @@ def world_name_lookup(self):

def get_regions(self, player: Optional[int] = None) -> Collection[Region]:
if player is not None:
return self.worlds[player].regions.region_cache.values()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.region_cache.values()
return self.worlds[player].regions.get_regions()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_regions()
for player in self.get_all_ids()))

def get_region(self, region_name: str, player: int) -> Region:
return self.worlds[player].get_region(region_name)
return self.worlds[player].regions.get_region(region_name)

def get_entrance(self, entrance_name: str, player: int) -> Entrance:
return self.worlds[player].get_entrance(entrance_name)
return self.worlds[player].regions.get_entrance(entrance_name)

def get_location(self, location_name: str, player: int) -> Location:
return self.worlds[player].get_location(location_name)
return self.worlds[player].regions.get_location(location_name)

def get_all_state(self, use_cache: bool) -> CollectionState:
cached = getattr(self, "_all_state", None)
Expand Down Expand Up @@ -500,8 +522,8 @@ def push_item(self, location: Location, item: Item, collect: bool = True):

def get_entrances(self, player: Optional[int] = None) -> Iterable[Entrance]:
if player is not None:
return self.worlds[player].regions.entrance_cache.values()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.entrance_cache.values()
return self.worlds[player].regions.get_entrances()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_entrances()
for player in self.get_all_ids()))

def register_indirect_condition(self, region: Region, entrance: Entrance):
Expand All @@ -511,8 +533,8 @@ def register_indirect_condition(self, region: Region, entrance: Entrance):

def get_locations(self, player: Optional[int] = None) -> Iterable[Location]:
if player is not None:
return self.worlds[player].regions.location_cache.values()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.location_cache.values()
return self.worlds[player].regions.get_locations()
return Utils.RepeatableChain(tuple(self.worlds[player].regions.get_locations()
for player in self.get_all_ids()))

def get_unfilled_locations(self, player: Optional[int] = None) -> List[Location]:
Expand Down Expand Up @@ -1015,7 +1037,7 @@ class LocationRegister(Register):
def __delitem__(self, index: int) -> None:
location: Location = self._list.__getitem__(index)
self._list.__delitem__(index)
del(self.region_manager.location_cache[location.player][location.name])
del(self.region_manager.location_cache[location.name])

def insert(self, index: int, value: Location) -> None:
assert value.name not in self.region_manager.location_cache, \
Expand Down
6 changes: 3 additions & 3 deletions worlds/AutoWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,13 +532,13 @@ def create_filler(self) -> "Item":

# convenience methods
def get_location(self, location_name: str) -> "Location":
return self.regions.location_cache[location_name]
return self.regions.get_location(location_name)

def get_entrance(self, entrance_name: str) -> "Entrance":
return self.regions.entrance_cache[entrance_name]
return self.regions.get_entrance(entrance_name)

def get_region(self, region_name: str) -> "Region":
return self.regions.region_cache[region_name]
return self.regions.get_region(region_name)

@property
def player_name(self) -> str:
Expand Down

0 comments on commit 13e51df

Please sign in to comment.