diff --git a/retro_data_structures/asset_provider.py b/retro_data_structures/asset_provider.py index 049df32..3a65604 100644 --- a/retro_data_structures/asset_provider.py +++ b/retro_data_structures/asset_provider.py @@ -1,4 +1,5 @@ import logging +import typing from pathlib import Path from typing import List, BinaryIO, Optional @@ -26,16 +27,18 @@ def __init__(self, asset_id, reason: str): class AssetProvider: _pak_files: Optional[List[BinaryIO]] = None - def __init__(self, pak_paths: List[Path], target_game: Game): + def __init__(self, target_game: Game, pak_paths: List[Path], pak_files: Optional[List[typing.BinaryIO]] = None): self.pak_paths = pak_paths + self._pak_files = pak_files self.target_game = target_game self.loaded_assets = {} def __enter__(self): - self._pak_files = [ - path.open("rb") - for path in self.pak_paths - ] + if self._pak_files is None: + self._pak_files = [ + path.open("rb") + for path in self.pak_paths + ] self._paks = [] for i, pak_file in enumerate(self._pak_files): logger.info("Parsing PAK at %s", str(self.pak_paths[i])) diff --git a/retro_data_structures/cli.py b/retro_data_structures/cli.py index e3ebc40..781ed29 100644 --- a/retro_data_structures/cli.py +++ b/retro_data_structures/cli.py @@ -142,7 +142,7 @@ def do_decode_from_pak(args): paks_path: Path = args.paks_path asset_id: int = args.asset_id - with AssetProvider(list(paks_path.glob("*.pak")), game) as asset_provider: + with AssetProvider(game, list(paks_path.glob("*.pak"))) as asset_provider: print(asset_provider.get_asset(asset_id)) @@ -151,7 +151,7 @@ def list_dependencies(args): paks_path: Path = args.paks_path asset_ids: List[int] - with AssetProvider(list(paks_path.glob("*.pak")), game) as asset_provider: + with AssetProvider(game, list(paks_path.glob("*.pak"))) as asset_provider: if args.asset_ids is not None: asset_ids = args.asset_ids else: @@ -171,7 +171,7 @@ def do_convert(args): paks_path: Path = args.paks_path asset_ids: List[int] = args.asset_ids - with AssetProvider(list(paks_path.glob("*.pak")), source_game) as asset_provider: + with AssetProvider(source_game, list(paks_path.glob("*.pak"))) as asset_provider: next_id = 0xFFFF0000 def id_generator(asset_type): diff --git a/test/formats/test_ancs.py b/test/formats/test_ancs.py index 00bb80c..496db5d 100644 --- a/test/formats/test_ancs.py +++ b/test/formats/test_ancs.py @@ -30,7 +30,7 @@ def test_compare_p2(prime2_pwe_project): def test_dependencies_all_p1(prime1_pwe_project): pak_path = prime1_pwe_project.joinpath("Disc", "files") - with AssetProvider(list(pak_path.glob("*.pak")), Game.PRIME) as asset_provider: + with AssetProvider(Game.PRIME, list(pak_path.glob("*.pak"))) as asset_provider: asset_ids = [ asset_id for asset_id, (resource, _) in asset_provider._resource_by_asset_id.items()