diff --git a/pytmx/pytmx.py b/pytmx/pytmx.py index 7d40621..9bbe454 100644 --- a/pytmx/pytmx.py +++ b/pytmx/pytmx.py @@ -17,6 +17,7 @@ License along with pytmx. If not, see . """ + from __future__ import annotations import gzip @@ -54,6 +55,8 @@ "convert_to_bool", "resolve_to_class", "parse_properties", + "decode_gid", + "unpack_gids", ) logger = logging.getLogger(__name__) @@ -85,6 +88,7 @@ TiledLayer = Union[ "TiledTileLayer", "TiledImageLayer", "TiledGroupLayer", "TiledObjectGroup" ] +flag_cache: dict[int, TileFlags] = {} # need a more graceful way to handle annotations for optional dependencies if pygame: @@ -125,15 +129,19 @@ def decode_gid(raw_gid: int) -> tuple[int, TileFlags]: """ if raw_gid < GID_TRANS_ROT: return raw_gid, empty_flags - return ( - raw_gid & ~GID_MASK, - # TODO: cache all combinations of flags - TileFlags( - raw_gid & GID_TRANS_FLIPX == GID_TRANS_FLIPX, - raw_gid & GID_TRANS_FLIPY == GID_TRANS_FLIPY, - raw_gid & GID_TRANS_ROT == GID_TRANS_ROT, - ), + + # Check if the GID is already in the cache + if raw_gid in flag_cache: + return raw_gid & ~GID_MASK, flag_cache[raw_gid] + + # Calculate and cache the flags + flags = TileFlags( + raw_gid & GID_TRANS_FLIPX == GID_TRANS_FLIPX, + raw_gid & GID_TRANS_FLIPY == GID_TRANS_FLIPY, + raw_gid & GID_TRANS_ROT == GID_TRANS_ROT, ) + flag_cache[raw_gid] = flags + return raw_gid & ~GID_MASK, flags def reshape_data( diff --git a/tests/pytmx/test_pytmx.py b/tests/pytmx/test_pytmx.py index 9ab99ef..c8f1211 100644 --- a/tests/pytmx/test_pytmx.py +++ b/tests/pytmx/test_pytmx.py @@ -1,7 +1,24 @@ import unittest import pytmx -from pytmx import TiledElement, convert_to_bool +import base64 +import gzip +import zlib +import struct +from pytmx import ( + TiledElement, + TiledMap, + convert_to_bool, + TileFlags, + decode_gid, + unpack_gids, +) + +# Tiled gid flags +GID_TRANS_FLIPX = 1 << 31 +GID_TRANS_FLIPY = 1 << 30 +GID_TRANS_ROT = 1 << 29 +GID_MASK = GID_TRANS_FLIPX | GID_TRANS_FLIPY | GID_TRANS_ROT class TestConvertToBool(unittest.TestCase): @@ -64,6 +81,15 @@ def test_non_boolean_number_raises_error(self) -> None: with self.assertRaises(ValueError): convert_to_bool("200") + def test_edge_cases(self): + # Whitespace + self.assertTrue(convert_to_bool(" t ")) + self.assertFalse(convert_to_bool(" f ")) + + # Numeric edge cases + self.assertTrue(convert_to_bool(1e-10)) # Very small positive number + self.assertFalse(convert_to_bool(-1e-10)) # Very small negative number + class TiledMapTest(unittest.TestCase): filename = "tests/resources/test01.tmx" @@ -170,3 +196,145 @@ def test_reserved_names_check_disabled_with_option(self) -> None: def test_repr(self) -> None: self.element.name = "foo" self.assertEqual('', self.element.__repr__()) + + +class TestDecodeGid(unittest.TestCase): + def test_no_flags(self): + raw_gid = 100 + expected_gid, expected_flags = 100, TileFlags(False, False, False) + self.assertEqual(decode_gid(raw_gid), (expected_gid, expected_flags)) + + def test_individual_flags(self): + # Test for each flag individually + test_cases = [ + (GID_TRANS_FLIPX + 1, 1, TileFlags(True, False, False)), + (GID_TRANS_FLIPY + 1, 1, TileFlags(False, True, False)), + (GID_TRANS_ROT + 1, 1, TileFlags(False, False, True)), + ] + for raw_gid, expected_gid, expected_flags in test_cases: + self.assertEqual(decode_gid(raw_gid), (expected_gid, expected_flags)) + + def test_combinations_of_flags(self): + # Test combinations of flags + test_cases = [ + (GID_TRANS_FLIPX + GID_TRANS_FLIPY + 1, 1, TileFlags(True, True, False)), + (GID_TRANS_FLIPX + GID_TRANS_ROT + 1, 1, TileFlags(True, False, True)), + (GID_TRANS_FLIPY + GID_TRANS_ROT + 1, 1, TileFlags(False, True, True)), + ( + GID_TRANS_FLIPX + GID_TRANS_FLIPY + GID_TRANS_ROT + 1, + 1, + TileFlags(True, True, True), + ), + ] + for raw_gid, expected_gid, expected_flags in test_cases: + self.assertEqual(decode_gid(raw_gid), (expected_gid, expected_flags)) + + def test_edge_cases(self): + # Maximum GID + # max_gid = 2**32 - 1 + # self.assertEqual(decode_gid(max_gid), (max_gid & ~GID_MASK, TileFlags(False, False, False))) + + # Minimum GID + min_gid = 0 + self.assertEqual(decode_gid(min_gid), (min_gid, TileFlags(False, False, False))) + + # GID with all flags set + gid_all_flags = GID_TRANS_FLIPX + GID_TRANS_FLIPY + GID_TRANS_ROT + 1 + self.assertEqual(decode_gid(gid_all_flags), (1, TileFlags(True, True, True))) + + # GID with flags in different orders + test_cases = [ + (GID_TRANS_FLIPX + GID_TRANS_FLIPY + 1, 1, TileFlags(True, True, False)), + (GID_TRANS_FLIPY + GID_TRANS_FLIPX + 1, 1, TileFlags(True, True, False)), + (GID_TRANS_FLIPX + GID_TRANS_ROT + 1, 1, TileFlags(True, False, True)), + (GID_TRANS_ROT + GID_TRANS_FLIPX + 1, 1, TileFlags(True, False, True)), + ] + for raw_gid, expected_gid, expected_flags in test_cases: + self.assertEqual(decode_gid(raw_gid), (expected_gid, expected_flags)) + + +class TestRegisterGid(unittest.TestCase): + def setUp(self): + self.tmx_map = TiledMap() + + def test_register_gid_with_valid_tiled_gid(self): + gid = self.tmx_map.register_gid(123) + self.assertIsNotNone(gid) + + def test_register_gid_with_flags(self): + flags = TileFlags(1, 0, 1) + gid = self.tmx_map.register_gid(456, flags) + self.assertIsNotNone(gid) + + def test_register_gid_zero(self): + gid = self.tmx_map.register_gid(0) + self.assertEqual(gid, 0) + + def test_register_gid_max_gid(self): + max_gid = self.tmx_map.maxgid + self.tmx_map.register_gid(max_gid) + self.assertEqual(self.tmx_map.maxgid, max_gid + 1) + + def test_register_gid_duplicate_gid(self): + gid1 = self.tmx_map.register_gid(123) + gid2 = self.tmx_map.register_gid(123) + self.assertEqual(gid1, gid2) + + def test_register_gid_duplicate_gid_different_flags(self): + gid1 = self.tmx_map.register_gid(123, TileFlags(1, 0, 0)) + gid2 = self.tmx_map.register_gid(123, TileFlags(0, 1, 0)) + self.assertNotEqual(gid1, gid2) + + def test_register_gid_empty_flags(self): + gid = self.tmx_map.register_gid(123, TileFlags(0, 0, 0)) + self.assertIsNotNone(gid) + + def test_register_gid_all_flags_set(self): + gid = self.tmx_map.register_gid(123, TileFlags(1, 1, 1)) + self.assertIsNotNone(gid) + + def test_register_gid_repeated_registration(self): + gid1 = self.tmx_map.register_gid(123) + gid2 = self.tmx_map.register_gid(123) + self.assertEqual(gid1, gid2) + + +class TestUnpackGids(unittest.TestCase): + def test_base64_no_compression(self): + gids = [123, 456, 789] + data = struct.pack("