Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing, mostly in AutoWorld.py #476

Merged
merged 1 commit into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BaseClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MultiWorld():
plando_texts: List[Dict[str, str]]
plando_items: List[List[Dict[str, Any]]]
plando_connections: List
worlds: Dict[int, Any]
worlds: Dict[int, auto_world]
groups: Dict[int, Group]
itempool: List[Item]
is_race: bool = False
Expand Down
22 changes: 13 additions & 9 deletions Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,45 @@ class Version(typing.NamedTuple):
from yaml import Loader


def int16_as_bytes(value):
def int16_as_bytes(value: int) -> typing.List[int]:
value = value & 0xFFFF
return [value & 0xFF, (value >> 8) & 0xFF]


def int32_as_bytes(value):
def int32_as_bytes(value: int) -> typing.List[int]:
value = value & 0xFFFFFFFF
return [value & 0xFF, (value >> 8) & 0xFF, (value >> 16) & 0xFF, (value >> 24) & 0xFF]


def pc_to_snes(value):
def pc_to_snes(value: int) -> int:
return ((value << 1) & 0x7F0000) | (value & 0x7FFF) | 0x8000


def snes_to_pc(value):
def snes_to_pc(value: int) -> int:
return ((value & 0x7F0000) >> 1) | (value & 0x7FFF)


def cache_argsless(function):
RetType = typing.TypeVar("RetType")


def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
if function.__code__.co_argcount:
raise Exception("Can only cache 0 argument functions with this cache.")

result = sentinel = object()
sentinel = object()
result: typing.Union[object, RetType] = sentinel

def _wrap():
def _wrap() -> RetType:
nonlocal result
if result is sentinel:
result = function()
return result
return typing.cast(RetType, result)

return _wrap


def is_frozen() -> bool:
return getattr(sys, 'frozen', False)
return typing.cast(bool, getattr(sys, 'frozen', False))


def local_path(*path: str) -> str:
Expand Down
10 changes: 5 additions & 5 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ In addition, the following methods can be implemented and attributes can be set
#### generate_early

```python
def generate_early(self):
def generate_early(self) -> None:
# read player settings to world instance
self.final_boss_hp = self.world.final_boss_hp[self.player].value
```
Expand All @@ -456,7 +456,7 @@ def create_event(self, event: str):
#### create_items

```python
def create_items(self):
def create_items(self) -> None:
# Add items to the Multiworld.
# If there are two of the same item, the item has to be twice in the pool.
# Which items are added to the pool may depend on player settings,
Expand All @@ -483,7 +483,7 @@ def create_items(self):
#### create_regions

```python
def create_regions(self):
def create_regions(self) -> None:
# Add regions to the multiworld. "Menu" is the required starting point.
# Arguments to Region() are name, type, human_readable_name, player, world
r = Region("Menu", None, "Menu", self.player, self.world)
Expand Down Expand Up @@ -518,7 +518,7 @@ def create_regions(self):
#### generate_basic

```python
def generate_basic(self):
def generate_basic(self) -> None:
# place "Victory" at "Final Boss" and set collection as win condition
self.world.get_location("Final Boss", self.player)\
.place_locked_item(self.create_event("Victory"))
Expand All @@ -539,7 +539,7 @@ def generate_basic(self):
from ..generic.Rules import add_rule, set_rule, forbid_item
from Items import get_item_type

def set_rules(self):
def set_rules(self) -> None:
# For some worlds this step can be omitted if either a Logic mixin
# (see below) is used, it's easier to apply the rules from data during
# location generation or everything is in generate_basic
Expand Down
69 changes: 38 additions & 31 deletions worlds/AutoWorld.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

import logging
from typing import Dict, Set, Tuple, List, Optional, TextIO, Any, Callable, Union
from typing import Dict, FrozenSet, Set, Tuple, List, Optional, TextIO, Any, Callable, Union

from BaseClasses import MultiWorld, Item, CollectionState, Location
from Options import Option


class AutoWorldRegister(type):
world_types: Dict[str, World] = {}
world_types: Dict[str, AutoWorldRegister] = {}

def __new__(cls, name: str, bases, dct: Dict[str, Any]):
def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoWorldRegister:
if "web" in dct:
assert isinstance(dct["web"], WebWorld), "WebWorld has to be instantiated."
# filter out any events
Expand All @@ -34,7 +34,8 @@ def __new__(cls, name: str, bases, dct: Dict[str, Any]):
if "required_client_version" in dct and bases:
for base in bases:
if "required_client_version" in base.__dict__:
dct["required_client_version"] = max(dct["required_client_version"], base.required_client_version)
dct["required_client_version"] = max(dct["required_client_version"],
base.__dict__["required_client_version"])

# construct class
new_class = super().__new__(cls, name, bases, dct)
Expand All @@ -44,9 +45,9 @@ def __new__(cls, name: str, bases, dct: Dict[str, Any]):


class AutoLogicRegister(type):
def __new__(cls, name, bases, dct):
def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoLogicRegister:
new_class = super().__new__(cls, name, bases, dct)
function: Callable
function: Callable[..., Any]
for item_name, function in dct.items():
if item_name == "copy_mixin":
CollectionState.additional_copy_functions.append(function)
Expand All @@ -59,13 +60,13 @@ def __new__(cls, name, bases, dct):
return new_class


def call_single(world: MultiWorld, method_name: str, player: int, *args):
def call_single(world: MultiWorld, method_name: str, player: int, *args: Any) -> Any:
method = getattr(world.worlds[player], method_name)
return method(*args)


def call_all(world: MultiWorld, method_name: str, *args):
world_types = set()
def call_all(world: MultiWorld, method_name: str, *args: Any) -> None:
world_types: Set[AutoWorldRegister] = set()
for player in world.player_ids:
world_types.add(world.worlds[player].__class__)
call_single(world, method_name, player, *args)
Expand All @@ -76,7 +77,7 @@ def call_all(world: MultiWorld, method_name: str, *args):
stage_callable(world, *args)


def call_stage(world: MultiWorld, method_name: str, *args):
def call_stage(world: MultiWorld, method_name: str, *args: Any) -> None:
world_types = {world.worlds[player].__class__ for player in world.player_ids}
for world_type in world_types:
stage_callable = getattr(world_type, f"stage_{method_name}", None)
Expand All @@ -101,10 +102,12 @@ class World(metaclass=AutoWorldRegister):
"""A World object encompasses a game's Items, Locations, Rules and additional data or functionality required.
A Game should have its own subclass of World in which it defines the required data structures."""

options: Dict[str, type(Option)] = {} # link your Options mapping
options: Dict[str, Option[Any]] = {} # link your Options mapping
game: str # name the game
topology_present: bool = False # indicate if world type has any meaningful layout/pathing
all_item_and_group_names: Set[str] = frozenset() # gets automatically populated with all item and item group names

# gets automatically populated with all item and item group names
all_item_and_group_names: FrozenSet[str] = frozenset()

# map names to their IDs
item_name_to_id: Dict[str, int] = {}
Expand All @@ -126,7 +129,7 @@ class World(metaclass=AutoWorldRegister):
# update this if the resulting multidata breaks forward-compatibility of the server
required_server_version: Tuple[int, int, int] = (0, 2, 4)

hint_blacklist: Set[str] = frozenset() # any names that should not be hintable
hint_blacklist: FrozenSet[str] = frozenset() # any names that should not be hintable

# NOTE: remote_items and remote_start_inventory are now available in the network protocol for the client to set.
# These values will be removed.
Expand Down Expand Up @@ -168,61 +171,65 @@ def __init__(self, world: MultiWorld, player: int):
# can also be implemented as a classmethod and called "stage_<original_name>",
# in that case the MultiWorld object is passed as an argument and it gets called once for the entire multiworld.
# An example of this can be found in alttp as stage_pre_fill
def generate_early(self):
def generate_early(self) -> None:
pass

def create_regions(self):
def create_regions(self) -> None:
pass

def create_items(self):
def create_items(self) -> None:
pass

def set_rules(self):
def set_rules(self) -> None:
pass

def generate_basic(self):
def generate_basic(self) -> None:
pass

def pre_fill(self):
def pre_fill(self) -> None:
"""Optional method that is supposed to be used for special fill stages. This is run *after* plando."""
pass

@classmethod
def fill_hook(cls, progitempool: List[Item], nonexcludeditempool: List[Item],
localrestitempool: Dict[int, List[Item]], nonlocalrestitempool: Dict[int, List[Item]],
restitempool: List[Item], fill_locations: List[Location]):
def fill_hook(cls,
progitempool: List[Item],
nonexcludeditempool: List[Item],
localrestitempool: Dict[int, List[Item]],
nonlocalrestitempool: Dict[int, List[Item]],
restitempool: List[Item],
fill_locations: List[Location]) -> None:
"""Special method that gets called as part of distribute_items_restrictive (main fill).
This gets called once per present world type."""
pass

def post_fill(self):
def post_fill(self) -> None:
"""Optional Method that is called after regular fill. Can be used to do adjustments before output generation."""

def generate_output(self, output_directory: str):
def generate_output(self, output_directory: str) -> None:
"""This method gets called from a threadpool, do not use world.random here.
If you need any last-second randomization, use MultiWorld.slot_seeds[slot] instead."""
pass

def fill_slot_data(self) -> dict:
def fill_slot_data(self) -> Dict[str, Any]: # json of WebHostLib.models.Slot
"""Fill in the slot_data field in the Connected network package."""
return {}

def modify_multidata(self, multidata: dict):
def modify_multidata(self, multidata: Dict[str, Any]) -> None: # TODO: TypedDict for multidata?
"""For deeper modification of server multidata."""
pass

# Spoiler writing is optional, these may not get called.
def write_spoiler_header(self, spoiler_handle: TextIO):
def write_spoiler_header(self, spoiler_handle: TextIO) -> None:
"""Write to the spoiler header. If individual it's right at the end of that player's options,
if as stage it's right under the common header before per-player options."""
pass

def write_spoiler(self, spoiler_handle: TextIO):
def write_spoiler(self, spoiler_handle: TextIO) -> None:
"""Write to the spoiler "middle", this is after the per-player options and before locations,
meant for useful or interesting info."""
pass

def write_spoiler_end(self, spoiler_handle: TextIO):
def write_spoiler_end(self, spoiler_handle: TextIO) -> None:
"""Write to the end of the spoiler"""
pass

Expand All @@ -236,7 +243,7 @@ def create_item(self, name: str) -> Item:
def get_filler_item_name(self) -> str:
"""Called when the item pool needs to be filled with additional items to match location count."""
logging.warning(f"World {self} is generating a filler item without custom filler pool.")
return self.world.random.choice(self.item_name_to_id)
return self.world.random.choice(tuple(self.item_name_to_id.keys()))

# decent place to implement progressive items, in most cases can stay as-is
def collect_item(self, state: CollectionState, item: Item, remove: bool = False) -> Optional[str]:
Expand All @@ -247,6 +254,7 @@ def collect_item(self, state: CollectionState, item: Item, remove: bool = False)
:param remove: indicate if this is meant to remove from state instead of adding."""
if item.advancement:
return item.name
return None

# called to create all_state, return Items that are created during pre_fill
def get_pre_fill_items(self) -> List[Item]:
Expand Down Expand Up @@ -277,4 +285,3 @@ def create_filler(self) -> Item:
# please use a prefix as all of them get clobbered together
class LogicMixin(metaclass=AutoLogicRegister):
pass