Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a basic E2E test suite #638

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ roms/*
stats/
stream/
plugins/
tests/last_frame_timeout.png
1 change: 1 addition & 0 deletions modules/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, initial_bot_mode: str = "Manual"):
self.profile: Optional["Profile"] = None
self.stats: Optional["StatsDatabase"] = None
self.debug: bool = False
self.testing: bool = False

self._current_message: str = ""

Expand Down
9 changes: 5 additions & 4 deletions modules/encounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,11 @@ def handle_encounter(
battle_is_active = get_game_state() in (GameState.BATTLE, GameState.BATTLE_STARTING, GameState.BATTLE_ENDING)

if is_of_interest:
filename_suffix = (
f"{encounter_info.value.name}_{make_string_safe_for_file_name(pokemon.species_name_for_stats)}"
)
context.emulator.create_save_state(suffix=filename_suffix)
if not context.testing:
filename_suffix = (
f"{encounter_info.value.name}_{make_string_safe_for_file_name(pokemon.species_name_for_stats)}"
)
context.emulator.create_save_state(suffix=filename_suffix)

if context.config.battle.auto_catch and not disable_auto_catch and battle_is_active:
encounter_info.battle_action = BattleAction.Catch
Expand Down
18 changes: 10 additions & 8 deletions modules/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ def _load_event_flags_and_vars(file_name: str) -> None: # TODO Japanese ROMs no

_event_flags.clear()
_reverse_event_flags.clear()
for s in open(get_data_path() / "event_flags" / file_name):
number, name = s.strip().split(" ")
_event_flags[name] = (int(number) // 8) + flags_offset, int(number) % 8
_reverse_event_flags[int(number)] = name
with open(get_data_path() / "event_flags" / file_name) as file_handle:
for s in file_handle:
number, name = s.strip().split(" ")
_event_flags[name] = (int(number) // 8) + flags_offset, int(number) % 8
_reverse_event_flags[int(number)] = name

_event_vars.clear()
_reverse_event_vars.clear()
for s in open(get_data_path() / "event_vars" / file_name):
number, name = s.strip().split(" ")
_event_vars[name] = int(number) * 2 + vars_offset
_reverse_event_vars[int(number)] = name
with open(get_data_path() / "event_vars" / file_name) as file_handle:
for s in file_handle:
number, name = s.strip().split(" ")
_event_vars[name] = int(number) * 2 + vars_offset
_reverse_event_vars[int(number)] = name


def _prepare_character_tables() -> None:
Expand Down
4 changes: 0 additions & 4 deletions modules/gui/debug_menu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
import tkinter
import webbrowser
import zlib
Expand Down Expand Up @@ -29,16 +28,13 @@
pack_uint8,
pack_uint32,
set_event_var,
get_game_state,
GameState,
read_symbol,
)
from modules.modes import BotListener, BotMode, FrameInfo
from modules.player import get_player
from modules.pokemon import get_opponent
from modules.pokemon_party import get_party, get_party_size
from modules.runtime import get_sprites_path, get_base_path
from modules.tasks import task_is_active


def _create_save_state() -> None:
Expand Down
37 changes: 26 additions & 11 deletions modules/libmgba.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class LibmgbaEmulator:
_audio_sample_rate: int = 32768
_last_audio_data: Queue[bytes]

def __init__(self, profile: Profile, on_frame_callback: callable):
console.print(f"Running [cyan]{libmgba_version_string()}[/]")
def __init__(self, profile: Profile, on_frame_callback: callable, is_test_run: bool = False):
if not is_test_run:
console.print(f"Running [cyan]{libmgba_version_string()}[/]")

# Prevents relentless spamming to stdout by libmgba.
mgba.log.silence()
Expand All @@ -121,12 +122,13 @@ def __init__(self, profile: Profile, on_frame_callback: callable):
# libmgba needs a save file to be loaded, or otherwise it will not save anything
# to disk if the player saves the game. This can be an empty file.
self._current_save_path = profile.path / "current_save.sav"
if not self._current_save_path.exists():
# Create an empty file if a save game does not exist yet.
with open(self._current_save_path, "wb"):
pass
self._save = mgba.vfs.open_path(str(self._current_save_path), "r+")
self._core.load_save(self._save)
if not is_test_run:
if not self._current_save_path.exists():
# Create an empty file if a save game does not exist yet.
with open(self._current_save_path, "wb"):
pass
self._save = mgba.vfs.open_path(str(self._current_save_path), "r+")
self._core.load_save(self._save)
self._last_audio_data = Queue(maxsize=128)

self._screen = mgba.image.Image(*self._core.desired_video_dimensions())
Expand All @@ -136,7 +138,7 @@ def __init__(self, profile: Profile, on_frame_callback: callable):
# Whenever the emulator closes, it stores the current state in `current_state.ss1`.
# Load this file if it exists, to continue exactly where we left off.
self._current_state_path = profile.path / "current_state.ss1"
if self._current_state_path.exists():
if not is_test_run and self._current_state_path.exists():
with open(self._current_state_path, "rb") as state_file:
self.load_save_state(state_file.read())

Expand All @@ -151,8 +153,9 @@ def __init__(self, profile: Profile, on_frame_callback: callable):
self._pressed_inputs: int = 0
self._held_inputs: int = 0

atexit.register(self.shutdown)
self._core._callbacks.savedata_updated.append(self.backup_current_save_game)
if not is_test_run:
atexit.register(self.shutdown)
self._core._callbacks.savedata_updated.append(self.backup_current_save_game)

def _reset_audio(self) -> None:
"""
Expand Down Expand Up @@ -368,6 +371,18 @@ def load_save_state(self, state: bytes) -> None:
vfile.seek(0, whence=0)
self._core.load_state(vfile)

def load_save_game(self, save_game: bytes) -> None:
"""
Loads GBA save data from a string. This should only be used for testing
because it means that save data will not be written out to a file after
an in-game save.
:param save_game: The raw save game data.
"""
vfile = mgba.vfs.VFile.fromEmpty()
vfile.write(save_game, len(save_game))
vfile.seek(0, whence=0)
self._core.load_save(vfile)

def read_save_data(self) -> bytes:
"""
Reads and returns the contents of the save game (SRAM/Flash)
Expand Down
40 changes: 27 additions & 13 deletions modules/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def to_dict(self) -> dict:
return data


_map_layout_cache: dict[tuple[int, int], bytes] = {}
_map_layout_cache: dict[str, dict[tuple[int, int], bytes]] = {}


class MapLocation:
Expand All @@ -744,10 +744,14 @@ def __init__(self, map_header: bytes, map_group: int, map_number: int, local_pos
@cached_property
def _map_layout(self) -> bytes:
global _map_layout_cache
if self.map_group_and_number not in _map_layout_cache:
if context.rom.id not in _map_layout_cache:
_map_layout_cache[context.rom.id] = {}
if self.map_group_and_number not in _map_layout_cache[context.rom.id]:
map_layout_pointer = unpack_uint32(self._map_header[:4])
_map_layout_cache[self.map_group_and_number] = context.emulator.read_bytes(map_layout_pointer, 24)
return _map_layout_cache[self.map_group_and_number]
_map_layout_cache[context.rom.id][self.map_group_and_number] = context.emulator.read_bytes(
map_layout_pointer, 24
)
return _map_layout_cache[context.rom.id][self.map_group_and_number]

@cached_property
def _metatile_attributes(self) -> tuple[int, int, int]:
Expand Down Expand Up @@ -1704,7 +1708,7 @@ def get_map_data_for_current_position() -> MapLocation | None:
return MapLocation(read_symbol("gMapHeader"), map_group, map_number, player.local_coordinates)


_map_header_cache: dict[tuple[int, int], bytes] = {}
_map_header_cache: dict[str, dict[tuple[int, int], bytes]] = {}


def get_map_data(
Expand All @@ -1714,7 +1718,10 @@ def get_map_data(
if not isinstance(map_group_and_number, tuple):
map_group_and_number = map_group_and_number.value

if len(_map_header_cache) == 0:
if context.rom.id not in _map_header_cache:
_map_header_cache[context.rom.id] = {}

if len(_map_header_cache[context.rom.id]) == 0:
from modules.map_data import MapGroupFRLG, MapGroupRSE, MapRSE

if context.rom.is_rse:
Expand All @@ -1734,13 +1741,16 @@ def get_map_data(

map_header_pointer = unpack_uint32(context.emulator.read_bytes(group_pointer + 4 * map_index, 4))
map_header = context.emulator.read_bytes(map_header_pointer, 0x1C)
_map_header_cache[(group_index, map_index)] = map_header
_map_header_cache[context.rom.id][(group_index, map_index)] = map_header

if map_group_and_number not in _map_header_cache:
if map_group_and_number not in _map_header_cache[context.rom.id]:
raise ValueError(f"Tried to access invalid map: ({map_group_and_number})")

return MapLocation(
_map_header_cache[map_group_and_number], map_group_and_number[0], map_group_and_number[1], local_position
_map_header_cache[context.rom.id][map_group_and_number],
map_group_and_number[0],
map_group_and_number[1],
local_position,
)


Expand Down Expand Up @@ -1827,12 +1837,16 @@ def to_dict(self):
}


_wild_encounters_cache: dict[tuple[int, int], WildEncounterList] = {}
_wild_encounters_cache: dict[str, dict[tuple[int, int], WildEncounterList]] = {}


def get_wild_encounters_for_map(map_group: int, map_number: int) -> WildEncounterList | None:
global _wild_encounters_cache
if len(_wild_encounters_cache) == 0:

if context.rom.id not in _wild_encounters_cache:
_wild_encounters_cache[context.rom.id] = {}

if len(_wild_encounters_cache[context.rom.id]) == 0:
types = (
(4, 8, "land", 12, (20, 20, 10, 10, 10, 10, 5, 5, 4, 4, 1, 1)),
(8, 12, "surf", 5, (60, 30, 5, 4, 1)),
Expand Down Expand Up @@ -1883,9 +1897,9 @@ def get_encounters_list(address: int, length: int, encounter_rates: list[int]) -
data["old_rod_encounters"] = get_encounters_list(list_pointer, 2, rates[:2])
data["good_rod_encounters"] = get_encounters_list(list_pointer + 8, 3, rates[2:5])
data["super_rod_encounters"] = get_encounters_list(list_pointer + 20, 5, rates[5:10])
_wild_encounters_cache[(group, number)] = WildEncounterList(**data)
_wild_encounters_cache[context.rom.id][(group, number)] = WildEncounterList(**data)

return _wild_encounters_cache.get((map_group, map_number))
return _wild_encounters_cache[context.rom.id].get((map_group, map_number))


@dataclass
Expand Down
23 changes: 14 additions & 9 deletions modules/map_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def contains_global_coordinates(self, global_coordinates: tuple[int, int]) -> bo
)


_maps: dict[tuple[int, int], PathMap] = {}
_maps: dict[str, dict[tuple[int, int], PathMap]] = {}


def _get_connection_for_direction(map_data: MapLocation, direction: str) -> tuple[tuple[int, int], int] | None:
Expand All @@ -318,8 +318,11 @@ def _get_connection_for_direction(map_data: MapLocation, direction: str) -> tupl
def _get_all_maps_metadata() -> dict[tuple[int, int], PathMap]:
global _maps

if len(_maps) > 0:
return _maps
game_key = context.rom.id
if game_key in _maps:
return _maps[game_key]

_maps[game_key] = {}

if context.rom.is_rse:
maps_enum = MapRSE
Expand All @@ -344,12 +347,14 @@ def _get_all_maps_metadata() -> dict[tuple[int, int], PathMap]:
_get_connection_for_direction(map_data, "South"),
_get_connection_for_direction(map_data, "West"),
]
_maps[map_address.value] = PathMap(map_address.value, map_data.map_size, None, -1, map_connections, None)
_maps[game_key][map_address.value] = PathMap(
map_address.value, map_data.map_size, None, -1, map_connections, None
)

# For each map, find all connected maps and set an offset for each of them
current_map_level = 0
for map_address in reversed(_maps):
map = _maps[map_address]
for map_address in reversed(_maps[game_key]):
map = _maps[game_key][map_address]
if map.offset is None:
map.offset = (0, 0)
map.level = current_map_level
Expand All @@ -359,7 +364,7 @@ def _get_all_maps_metadata() -> dict[tuple[int, int], PathMap]:
map_queue.put_nowait(map_address)

while not map_queue.empty():
map_to_check = _maps[map_queue.get_nowait()]
map_to_check = _maps[game_key][map_queue.get_nowait()]
for direction in Direction:
connection = map_to_check.connections[direction]
if connection is not None:
Expand All @@ -368,7 +373,7 @@ def _get_all_maps_metadata() -> dict[tuple[int, int], PathMap]:
interconnected_maps.add(connection_address)
map_queue.put_nowait(connection_address)

connected_map = _maps[connection_address]
connected_map = _maps[game_key][connection_address]
connected_map.level = current_map_level
if direction is Direction.North:
connected_map.offset = (
Expand All @@ -395,7 +400,7 @@ def _get_all_maps_metadata() -> dict[tuple[int, int], PathMap]:

current_map_level += 1

return _maps
return _maps[game_key]


def _get_map_metadata(map_address: tuple[int, int]) -> PathMap:
Expand Down
4 changes: 4 additions & 0 deletions modules/roms.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def is_gen3(self) -> bool:
def is_gen2(self) -> bool:
return self.is_crystal or self.is_gs

@property
def id(self) -> str:
return f"{self.game_code}{self.language.value}{self.revision}"


class InvalidROMError(Exception):
pass
Expand Down
32 changes: 18 additions & 14 deletions modules/save_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,26 @@ class MigrationError(Exception):


def migrate_save_state(file: IO, profile_name: str, selected_rom: ROM) -> Profile:
selected_rom, state_data, savegame_data = guess_rom_from_save_state(file, selected_rom)

profile = create_profile(profile_name, selected_rom)
if state_data is not None:
with open(profile.path / "current_state.ss1", "wb") as state_file:
state_file.write(state_data)

if savegame_data is not None:
with open(profile.path / "current_save.sav", "wb") as save_file:
save_file.write(savegame_data)

file.close()

return profile


def guess_rom_from_save_state(file, selected_rom) -> tuple[ROM, bytes, bytes | None]:
file.seek(0)
magic = file.read(4)
file.seek(0)

# mGBA state files can either contain the raw serialised state data, or it can
# contain a PNG file that contains a custom 'gbAs' chunk, which in turn contains
# the actual (zlib-compressed) state data. We'd like to support both.
Expand Down Expand Up @@ -58,19 +74,7 @@ def migrate_save_state(file: IO, profile_name: str, selected_rom: ROM) -> Profil
'Please place your .gba ROMs in the "roms/" folder.'
)
selected_rom = matching_rom

profile = create_profile(profile_name, selected_rom)
if state_data is not None:
with open(profile.path / "current_state.ss1", "wb") as state_file:
state_file.write(state_data)

if savegame_data is not None:
with open(profile.path / "current_save.sav", "wb") as save_file:
save_file.write(savegame_data)

file.close()

return profile
return selected_rom, state_data, savegame_data


def get_state_data_from_mgba_state_file(file: IO) -> tuple[bytes, bytes | None]:
Expand Down
Loading