diff --git a/changes/257.feature.md b/changes/257.feature.md index 32fefeab..453b22bf 100644 --- a/changes/257.feature.md +++ b/changes/257.feature.md @@ -1,9 +1,12 @@ - Added the `NBTag` to deal with NBT data: - The `NBTag` class is the base class for all NBT tags and provides the basic functionality to serialize and deserialize NBT data from and to a `Buffer` object. - The classes `EndNBT`, `ByteNBT`, `ShortNBT`, `IntNBT`, `LongNBT`, `FloatNBT`, `DoubleNBT`, `ByteArrayNBT`, `StringNBT`, `ListNBT`, `CompoundNBT`, `IntArrayNBT`and `LongArrayNBT` were added and correspond to the NBT types described in the [NBT specification](https://wiki.vg/NBT#Specification). - - NBT tags can be created using the `NBTag.from_object()` method, which automatically selects the correct tag type based on the object's type and works recursively for lists and dictionaries. - - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. + - NBT tags can be created using the `NBTag.from_object()` method and a schema that describes the NBT tag structure. + Compound tags are represented as dictionaries, list tags as lists, and primitive tags as their respective Python types. + The implementation allows to add custom classes to the schema to handle custom NBT tags if they inherit the `:class: NBTagConvertible` class. + - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. Use include_schema=True to include the schema in the output, and `include_name=True` to include the name of the tag in the output. In that case the output will be a dictionary with a single key that is the name of the tag and the value is the object representation of the tag. - The `NBTag.serialize()` can be used to serialize an NBT tag to a new `Buffer` object. - The `NBTag.deserialize(buffer)` can be used to deserialize an NBT tag from a `Buffer` object. - If the buffer already exists, the `NBTag.write_to(buffer, with_type=True, with_name=True)` method can be used to write the NBT tag to the buffer (and in that case with the type and name in the right format). - The `NBTag.read_from(buffer, with_type=True, with_name=True)` method can be used to read an NBT tag from the buffer (and in that case with the type and name in the right format). + - The `NBTag.value` property can be used to get the value of the NBT tag as a Python object. diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py index e8a057bb..e799f2a3 100644 --- a/mcproto/types/nbt.py +++ b/mcproto/types/nbt.py @@ -1,17 +1,18 @@ from __future__ import annotations -import warnings -from abc import ABCMeta +from abc import abstractmethod from enum import IntEnum -from typing import ClassVar, List, Mapping, Union, cast +from typing import Iterator, List, Mapping, Sequence, Tuple, Type, Union, cast -from typing_extensions import TypeAlias +from typing_extensions import Protocol, TypeAlias, override +from typing import runtime_checkable from mcproto.buffer import Buffer from mcproto.protocol.base_io import StructFormat from mcproto.types.abc import MCType __all__ = [ + "NBTagConvertible", "NBTagType", "NBTag", "EndNBT", @@ -29,103 +30,105 @@ "LongArrayNBT", ] - -# region NBT Specification """ -Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt +Implementation of the NBT (Named Binary Tag) format used in Minecraft as described in the NBT specification +(:seealso: :class:`NBTagType`). +""" +# region NBT Specification -Named Binary Tag specification -NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller amounts -of additional data. -An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. +class NBTagType(IntEnum): + """Enumeration of the different types of NBT tags. -A Named Tag has the following format: + Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt - byte tagType - TAG_String name - [payload] + Named Binary Tag specification -The tagType is a single byte defining the contents of the payload of the tag. + NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller + amounts of additional data. + An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. -The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with the -tagType. -The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized tag -names. -Exception: If tagType is TAG_End, the name is skipped and assumed to be "". + A Named Tag has the following format: -The [payload] varies by tagType. + byte tagType + TAG_String name + [payload] -Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) only -contains the payload. + The tagType is a single byte defining the contents of the payload of the tag. + The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with + the tagType. + The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized + tag names. + Exception: If tagType is TAG_End, the name is skipped and assumed to be "". -The tag types and respective payloads are: + The [payload] varies by tagType. - TYPE: 0 NAME: TAG_End - Payload: None. - Note: This tag is used to mark the end of a list. - Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". - (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) + only contains the payload. - TYPE: 1 NAME: TAG_Byte - Payload: A single signed byte (8 bits) - TYPE: 2 NAME: TAG_Short - Payload: A signed short (16 bits, big endian) + The tag types and respective payloads are: - TYPE: 3 NAME: TAG_Int - Payload: A signed short (32 bits, big endian) + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) - TYPE: 4 NAME: TAG_Long - Payload: A signed long (64 bits, big endian) + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) - TYPE: 5 NAME: TAG_Float - Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) - TYPE: 6 NAME: TAG_Double - Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + TYPE: 3 NAME: TAG_Int + Payload: A signed short (32 bits, big endian) - TYPE: 7 NAME: TAG_Byte_Array - Payload: TAG_Int length - An array of bytes of unspecified format. The length of this array is bytes + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) - TYPE: 8 NAME: TAG_String - Payload: TAG_Short length - An array of bytes defining a string in UTF-8 format. The length of this array is bytes + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) - TYPE: 9 NAME: TAG_List - Payload: TAG_Byte tagId - TAG_Int length - A sequential list of Tags (not Named Tags), of type . The length of this array is Tags - Notes: All tags share the same type. + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) - TYPE: 10 NAME: TAG_Compound - Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. - TAG_End end - Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading - until the next TAG_End will not work. - The names of the named tags have to be unique within each TAG_Compound - The order of the tags is not guaranteed. + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is bytes + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is bytes - // NEW TAGS - TYPE: 11 NAME: TAG_Int_Array - Payload: TAG_Int length - An array of integers. The length of this array is integers + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type . The length of this array is + Tags + Notes: All tags share the same type. - TYPE: 12 NAME: TAG_Long_Array - Payload: TAG_Int length - An array of longs. The length of this array is longs + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. -""" -# endregion -# region NBT base classes/types + // NEW TAGS + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is integers + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. The length of this array is longs -class NBTagType(IntEnum): - """Types of NBT tags.""" + + """ END = 0 BYTE = 1 @@ -145,44 +148,62 @@ class NBTagType(IntEnum): PayloadType: TypeAlias = Union[ int, float, - bytearray, bytes, str, - List["PayloadType"], - Mapping[str, "PayloadType"], - List[int], "NBTag", - List["NBTag"], + Sequence["PayloadType"], + Mapping[str, "PayloadType"], ] -class _MetaNBTag(ABCMeta): - """Metaclass for NBT tags.""" +@runtime_checkable +class NBTagConvertible(Protocol): + """Protocol for objects that can be converted to an NBT tag.""" + + __slots__ = () + + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + :param name: The name of the tag. + + :return: The NBT tag created from the object. + """ + ... - TYPE: NBTagType = NBTagType.COMPOUND - def __new__(cls, name: str, bases: tuple[type], namespace: dict, **kwargs): - new_cls: NBTag = super().__new__(cls, name, bases, namespace) # type: ignore - if name != "NBTag": - NBTag.ASSOCIATED_TYPES[new_cls.TYPE] = new_cls # type: ignore - return new_cls +FromObjectType: TypeAlias = Union[ + int, + float, + bytes, + str, + NBTagConvertible, + Sequence["FromObjectType"], + Mapping[str, "FromObjectType"], +] +FromObjectSchema: TypeAlias = Union[ + Type["NBTag"], + Type[NBTagConvertible], + Sequence["FromObjectSchema"], + Mapping[str, "FromObjectSchema"], +] -class NBTag(MCType, metaclass=_MetaNBTag): - """Base class for NBT tags.""" - __slots__ = ("name", "payload") +class NBTag(MCType, NBTagConvertible): + """Base class for NBT tags. - TYPE: ClassVar[NBTagType] = NBTagType.COMPOUND + In MC v1.20.2+ the type and name of the root tag are not written to the buffer, and unless specified, the type of + the tag is assumed to be TAG_Compound. + """ - ASSOCIATED_TYPES: ClassVar[dict[NBTagType, type[NBTag]]] = {} + __slots__ = ("name", "payload") def __init__(self, payload: PayloadType, name: str = ""): - if self.__class__ == NBTag: - raise TypeError("Cannot instantiate an NBTag object directly, use a subclass instead.") self.name = name self.payload = payload + @override def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: """Serialize the NBT tag to a buffer. @@ -196,21 +217,19 @@ def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: self.write_to(buf, with_name=with_name, with_type=with_type) return buf - def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> bool: + def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: if with_type: - buf.write_value(StructFormat.BYTE, self.TYPE.value) - if self.TYPE == NBTagType.END: - return False - if with_name: - if not self.name: - raise ValueError("Named tags must have a name.") + tag_type = _get_tag_type(self) + buf.write_value(StructFormat.BYTE, tag_type.value) + if with_name and self.name: StringNBT(self.name).write_to(buf, with_type=False, with_name=False) - return True + @abstractmethod def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the NBT tag to the buffer.""" ... + @override @classmethod def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True) -> NBTag: """Deserialize the NBT tag. @@ -229,7 +248,7 @@ def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True """ name, tag_type = cls._read_header(buf, with_name=with_name, read_type=with_type) - tag_class = NBTag.ASSOCIATED_TYPES[tag_type] + tag_class = ASSOCIATED_TYPES[tag_type] if cls not in (NBTag, tag_class): raise TypeError(f"Expected a {cls.__name__} tag, but found a different tag ({tag_class.__name__}).") @@ -251,16 +270,17 @@ def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = Tru :note: It is possible that this function reads nothing from the buffer if both with_name and read_type are set to False. """ - tag_type: NBTagType = cls.TYPE # default value if read_type: try: tag_type = NBTagType(buf.read_value(StructFormat.BYTE)) - except OSError: - raise IOError("Buffer is empty.") from None - except ValueError: - raise TypeError("Invalid tag type.") from None - - if tag_type == NBTagType.END: + except OSError as exc: + raise IOError("Buffer is empty.") from exc + except ValueError as exc: + raise TypeError("Invalid tag type.") from exc + else: + tag_type = _get_tag_type(cls) + + if tag_type is NBTagType.END: return "", tag_type name = StringNBT.read_from(buf, with_type=False, with_name=False).value if with_name else "" @@ -268,6 +288,7 @@ def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = Tru return name, tag_type @classmethod + @abstractmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> NBTag: """Read the NBT tag from the buffer. @@ -281,142 +302,133 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The NBT tag. """ - return cls.deserialize(buf, with_name=with_name, with_type=with_type) + ... @staticmethod - def from_object(data: object, /, name: str = "", *, use_int_array: bool = True) -> NBTag: # noqa: PLR0911,PLR0912 - """Create an NBT tag from an arbitrary (compatible) Python object. - - :param data: The object to convert to an NBT tag. - :param use_int_array: Whether to use IntArrayNBT and LongArrayNBT for lists of integers. - If set to False, all lists of integers will be considered as ListNBT. - :param name: The name of the resulting tag. Used for recursive calls. + def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") -> NBTag: + """Create an NBT tag from a dictionary. - :return: The NBT tag representing the object. + :param data: The dictionary to create the NBT tag from. + :param schema: The schema used to create the NBT tags. - :note: The function will attempt to convert the object to an NBT tag in the following way: - - If the object is a dictionary with a single key, the key will be used as the name of the tag. - - If the object is an integer, it will be converted to a ByteNBT, ShortNBT, IntNBT, or LongNBT tag - depending on the value. - - If the object is a list, it will be converted to a ListNBT tag. - - If the object is a dictionary, it will be converted to a CompoundNBT tag. - - If the object is a string, it will be converted to a StringNBT tag. - - If the object is a float, it will be converted to a FloatNBT tag. - - If the object can be serialized to bytes, it will be converted to a ByteArrayNBT tag. + If the schema is a list, the data must be a list and the schema must either contain a single element + representing the type of the elements in the list or multiple dictionaries or lists representing the types + of the elements in the list since they are the only types that have a variable type. + Example: + ```python + schema = [IntNBT] + data = [1, 2, 3] + schema = [[IntNBT], [StringNBT]] + data = [[1, 2, 3], ["a", "b", "c"]] + ``` - - If you want an object to be serialized in a specific way, you can implement: + If the schema is a dictionary, the data must be a dictionary and the schema must contain the keys and the + types of the values in the dictionary. + Example: ```python - def to_nbt(self, name: str = "") -> NBTag: - ... + schema = {"key": IntNBT} + data = {"key": 1} ``` + + If the schema is a subclass of NBTag, the data will be passed to the constructor of the schema. + If the schema is not a list, dictionary or subclass of NBTag, the data will be converted to an NBT tag + using the `to_nbt` method of the data. + + :param name: The name of the NBT tag. + + :return: The NBT tag created from the dictionary. """ - if hasattr(data, "to_nbt"): # For objects that can be converted to NBT - return data.to_nbt(name=name) # type: ignore - - if isinstance(data, int): - if -(1 << 7) <= data < 1 << 7: - return ByteNBT(data, name=name) - if -(1 << 15) <= data < 1 << 15: - return ShortNBT(data, name=name) - if -(1 << 31) <= data < 1 << 31: - return IntNBT(data, name=name) - if -(1 << 63) <= data < 1 << 63: - return LongNBT(data, name=name) - raise ValueError(f"Integer {data} is out of range.") - if isinstance(data, float): - return FloatNBT(data, name=name) - if isinstance(data, str): - return StringNBT(data, name=name) - if isinstance(data, (bytearray, bytes)): - if isinstance(data, bytearray): - data = bytes(data) - return ByteArrayNBT(data, name=name) - if isinstance(data, list): - if not data: - # Type END is used to mark an empty list - return ListNBT([], name=name) - first_type = type(data[0]) - if any(type(item) != first_type for item in data): - raise TypeError("All items in a list must be of the same type.") - - if issubclass(first_type, int) and use_int_array: - # Check the range of the integers in the list - use_int = all(-(1 << 31) <= item < 1 << 31 for item in data) - use_long = all(-(1 << 63) <= item < 1 << 63 for item in data) - if use_int: - return IntArrayNBT(data, name=name) - if not use_long: # Too big to fit in a long, won't fit in a List of Longs either - raise ValueError("Integer list contains values out of range.") - return LongArrayNBT(data, name=name) - return ListNBT([NBTag.from_object(item, use_int_array=use_int_array) for item in data], name=name) - if isinstance(data, dict): - if len(data) == 0: - return CompoundNBT([], name=name) - if len(data) == 1 and name == "": - key, value = next(iter(data.items())) - return NBTag.from_object(value, name=key, use_int_array=use_int_array) - payload = [] + if isinstance(schema, (list, tuple)): + if not isinstance(data, list): + raise TypeError("Expected a list, but found a different type.") + payload: list[NBTag] = [] + if len(schema) > 1: + if not all(isinstance(item, (list, dict)) for item in schema): + raise TypeError("Expected a list of lists or dictionaries, but found a different type.") + if len(schema) != len(data): + raise ValueError("The schema and the data must have the same length.") + for item, sub_schema in zip(data, schema): + payload.append(NBTag.from_object(item, sub_schema)) + else: + if len(schema) == 0 and len(data) > 0: + raise ValueError("The schema is empty, but the data is not.") + if len(schema) == 0: + return ListNBT([], name=name) + + schema = schema[0] + for item in data: + payload.append(NBTag.from_object(item, schema)) + return ListNBT(payload, name=name) + if isinstance(schema, dict): + if not isinstance(data, dict): + raise TypeError("Expected a dictionary, but found a different type.") + payload: list[NBTag] = [] for key, value in data.items(): - tag = NBTag.from_object(value, name=key, use_int_array=use_int_array) - payload.append(tag) - return CompoundNBT(payload, name) - if data is None: - warnings.warn("Converting None to an END tag.", stacklevel=2) - return EndNBT() # Should not be used - - try: - # Check if the object can be converted to bytes - return ByteArrayNBT(bytes(data), name=name) # type: ignore - except (TypeError, ValueError): - pass - raise TypeError(f"Cannot convert object of type {type(data)} to an NBT tag.") - - def to_object(self) -> Mapping[str, PayloadType] | PayloadType: - """Convert the NBT payload to a dictionary.""" - return CompoundNBT(self.payload).to_object() # allow NBTag.to_object to act as a dict - - def __getitem__(self, key: str | int) -> PayloadType: - """Get a tag from the list or compound tag.""" - if self.TYPE not in (NBTagType.LIST, NBTagType.COMPOUND, NBTagType.INT_ARRAY, NBTagType.LONG_ARRAY): - raise TypeError(f"Cannot get a tag by index from a non-LIST or non-COMPOUND tag ({self.TYPE}).") - - if not isinstance(self.payload, list): - raise AttributeError( - f"The payload of the tag is not a list ({self.TYPE}).\n" - "Check that the initialization of the tag is correct." + payload.append(NBTag.from_object(value, schema[key], name=key)) + return CompoundNBT(payload, name=name) + if not isinstance(schema, type) or not issubclass(schema, (NBTag, NBTagConvertible)): # type: ignore + raise TypeError("The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.") + if isinstance(data, schema): + return data.to_nbt(name=name) + schema = cast(Type[NBTag], schema) # Last option + if issubclass(schema, (CompoundNBT, ListNBT)): + raise ValueError("The schema must specify the type of the elements in CompoundNBT and ListNBT tags.") + if isinstance(data, dict): + if len(data) != 1: + raise ValueError("Expected a dictionary with a single key-value pair.") + key, value = next(iter(data.items())) + return schema.from_object(value, schema, name=key) + if not isinstance(data, (bytes, str, int, float, list)): + raise TypeError(f"Expected a bytes, str, int, float, but found {type(data).__name__}.") + if isinstance(data, list) and not all(isinstance(item, int) for item in data): + raise TypeError("Expected a list of integers.") # LongArrayNBT, IntArrayNBT + + data = cast(Union[bytes, str, int, float, List[int]], data) + return schema(data, name=name) + + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType] | tuple[PayloadType | Mapping[str, PayloadType], FromObjectSchema]: + """Convert the NBT tag to a python object. + + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A python object representing the payload of the tag. (default) + - A dictionary containing the name associated with a python object representing the payload of the tag. + - A tuple which includes one of the above and a schema describing the types of the original tag. + """ + if type(self) is EndNBT: + raise NotImplementedError("Cannot convert an EndNBT tag to a python object.") + if type(self) in (CompoundNBT, ListNBT): + raise TypeError( + f"Use the `{type(self).__name__}.to_object()` method to convert the tag to a python object." ) - if not isinstance(key, (str, int)): # type: ignore - raise TypeError("Key must be a string or an integer.") - - if isinstance(key, str): - if self.TYPE != NBTagType.COMPOUND: - raise TypeError(f"Cannot get a tag by name from a non-COMPOUND tag ({self.TYPE}).") - if not all(isinstance(tag, NBTag) for tag in self.payload): - raise AttributeError("The payload of the tag is not a list of NBTag objects.") - for tag in self.payload: - tag = cast(NBTag, tag) - if tag.name == key: - return tag - raise KeyError(f"No tag with the name {key!r} found.") - - # Key is an integer - if key < -len(self.payload) or key >= len(self.payload): - raise IndexError(f"Index {key} out of range.") - return self.payload[key] + result = self.payload if not include_name else {self.name: self.payload} + if include_schema: + return result, type(self) + return result + @override def __repr__(self) -> str: if self.name: - return f"{self.__class__.__name__}[{self.name!r}]({self.payload!r})" - return f"{self.__class__.__name__}({self.payload!r})" + return f"{type(self).__name__}[{self.name!r}]({self.payload!r})" + return f"{type(self).__name__}({self.payload!r})" + @override def __eq__(self, other: object) -> bool: """Check equality between two NBT tags.""" if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") - return self.name == other.name and self.TYPE == other.TYPE and self.payload == other.payload + if type(self) is not type(other): + return False + return self.name == other.name and self.payload == other.payload + @override def to_nbt(self, name: str = "") -> NBTag: """Convert the object to an NBT tag. @@ -426,12 +438,10 @@ def to_nbt(self, name: str = "") -> NBTag: return self @property + @abstractmethod def value(self) -> PayloadType: """Get the payload of the NBT tag in a python-friendly format.""" - obj = self.to_object() - if isinstance(obj, dict) and self.name: - return obj[self.name] - return obj + ... # endregion @@ -441,22 +451,23 @@ def value(self) -> PayloadType: class EndNBT(NBTag): """Sentinel tag used to mark the end of a TAG_Compound.""" - TYPE = NBTagType.END __slots__ = () def __init__(self): """Create a new EndNBT tag.""" super().__init__(0, name="") - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: """Write the EndNBT tag to the buffer. :param buf: The buffer to write to. :param with_type: Whether to include the type of the tag in the serialization. :param with_name: Whether to include the name of the tag in the serialization. """ - self._write_header(buf, with_type=with_type, with_name=with_name) + self._write_header(buf, with_type=with_type, with_name=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: """Read the EndNBT tag from the buffer. @@ -469,26 +480,37 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The EndNBT tag. """ _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") return EndNBT() - def to_object(self) -> Mapping[str, PayloadType]: + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType]: """Convert the EndNBT tag to a python object. - :return: An empty dictionary. + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + + :return: None """ - return {} + return NotImplemented + + @property + @override + def value(self) -> PayloadType: + """Get the payload of the EndNBT tag in a python-friendly format.""" + return NotImplemented class ByteNBT(NBTag): """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" - TYPE = NBTagType.BYTE - __slots__ = () payload: int + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ByteNBT tag to the buffer. @@ -502,6 +524,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) buf.write_value(StructFormat.BYTE, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: """Read the ByteNBT tag from the buffer. @@ -514,8 +537,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ByteNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 1: raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") @@ -526,17 +549,8 @@ def __int__(self) -> int: """Get the integer value of the ByteNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, int] | int: - """Convert the ByteNBT tag to a python object. - - :return: A dictionary containing the name and the integer value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> int: """Get the integer value of the IntNBT tag.""" return self.payload @@ -545,10 +559,9 @@ def value(self) -> int: class ShortNBT(ByteNBT): """NBT tag representing a short value, represented as a signed 16-bit integer.""" - TYPE = NBTagType.SHORT - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ShortNBT tag to the buffer. @@ -565,6 +578,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) buf.write(self.payload.to_bytes(2, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: """Read the ShortNBT tag from the buffer. @@ -577,8 +591,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ShortNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 2: raise IOError("Buffer does not contain enough data to read a short.") @@ -589,10 +603,9 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class IntNBT(ByteNBT): """NBT tag representing an integer value, represented as a signed 32-bit integer.""" - TYPE = NBTagType.INT - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the IntNBT tag to the buffer. @@ -610,6 +623,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # No more messing around with the struct, we want 32 bits of data no matter what buf.write(self.payload.to_bytes(4, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: """Read the IntNBT tag from the buffer. @@ -622,8 +636,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The IntNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 4: raise IOError("Buffer does not contain enough data to read an int.") @@ -634,10 +648,9 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class LongNBT(ByteNBT): """NBT tag representing a long value, represented as a signed 64-bit integer.""" - TYPE = NBTagType.LONG - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the LongNBT tag to the buffer. @@ -655,6 +668,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # No more messing around with the struct, we want 64 bits of data no matter what buf.write(self.payload.to_bytes(8, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: """Read the LongNBT tag from the buffer. @@ -667,8 +681,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The LongNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 8: raise IOError("Buffer does not contain enough data to read a long.") @@ -680,12 +694,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class FloatNBT(NBTag): """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" - TYPE = NBTagType.FLOAT - payload: float __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the FloatNBT tag to the buffer. @@ -698,6 +711,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.FLOAT, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: """Read the FloatNBT tag from the buffer. @@ -710,8 +724,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The FloatNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 4: raise IOError("Buffer does not contain enough data to read a float.") @@ -722,16 +736,7 @@ def __float__(self) -> float: """Get the float value of the FloatNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, float] | float: - """Convert the FloatNBT tag to a python object. - - :return: A dictionary containing the name and the float value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - + @override def __eq__(self, other: object) -> bool: """Check equality between two FloatNBT tags. @@ -744,14 +749,15 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") # Compare the float values with a small epsilon - if not (self.name == other.name and self.TYPE == other.TYPE): + if type(self) is not type(other): + return False + other.payload = cast(float, other.payload) + if self.name != other.name: return False - if not isinstance(other, self.__class__): # pragma: no cover - return False # Should not happen if nobody messes with the TYPE attribute - return abs(self.payload - other.payload) < 1e-6 @property + @override def value(self) -> float: """Get the float value of the FloatNBT tag.""" return self.payload @@ -760,10 +766,9 @@ def value(self) -> float: class DoubleNBT(FloatNBT): """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" - TYPE = NBTagType.DOUBLE - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the DoubleNBT tag to the buffer. @@ -776,6 +781,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.DOUBLE, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: """Read the DoubleNBT tag from the buffer. @@ -788,8 +794,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The DoubleNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 8: raise IOError("Buffer does not contain enough data to read a double.") @@ -800,12 +806,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class ByteArrayNBT(NBTag): """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.BYTE_ARRAY - __slots__ = () - payload: bytearray + payload: bytes + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ByteArrayNBT tag to the buffer. @@ -819,6 +824,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) buf.write(self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: """Read the ByteArrayNBT tag from the buffer. @@ -831,12 +837,12 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ByteArrayNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") try: length = IntNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a byte array.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a byte array.") from exc if length < 0: raise ValueError("Invalid byte array length.") @@ -846,32 +852,24 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) f"Buffer does not contain enough data to read the byte array ({buf.remaining} < {length} bytes)." ) - return ByteArrayNBT(buf.read(length), name=name) + return ByteArrayNBT(bytes(buf.read(length)), name=name) def __bytes__(self) -> bytes: """Get the bytes value of the ByteArrayNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, bytearray] | bytearray: - """Convert the ByteArrayNBT tag to a python object. - - :return: A dictionary containing the name and the byte array value of the tag. If the tag has no name, the - value will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - + @override def __repr__(self) -> str: """Get a string representation of the ByteArrayNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" @property - def value(self) -> bytearray: + @override + def value(self) -> bytes: """Get the bytes value of the ByteArrayNBT tag.""" return self.payload @@ -879,12 +877,11 @@ def value(self) -> bytearray: class StringNBT(NBTag): """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" - TYPE = NBTagType.STRING - __slots__ = () payload: str + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the StringNBT tag to the buffer. @@ -899,10 +896,11 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # Check the length of the string (can't generate strings that long in tests) raise ValueError("Maximum character limit for writing strings is 32767 characters.") # pragma: no cover - data = bytearray(self.payload, "utf-8") + data = bytes(self.payload, "utf-8") ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False) buf.write(data) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT: """Read the StringNBT tag from the buffer. @@ -915,12 +913,12 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The StringNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") try: length = ShortNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a string.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a string.") from exc if length < 0: raise ValueError("Invalid string length.") @@ -933,21 +931,13 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) except UnicodeDecodeError: raise # We want to know it + @override def __str__(self) -> str: """Get the string value of the StringNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, str] | str: - """Convert the StringNBT tag to a python object. - - :return: A dictionary containing the name and the string value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> str: """Get the string value of the StringNBT tag.""" return self.payload @@ -956,12 +946,11 @@ def value(self) -> str: class ListNBT(NBTag): """NBT tag representing a list of tags. All tags in the list must be of the same type.""" - TYPE = NBTagType.LIST - __slots__ = () payload: list[NBTag] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ListNBT tag to the buffer. @@ -986,17 +975,18 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) "objects to tags first." ) - tag_type = self.payload[0].TYPE + tag_type = _get_tag_type(self.payload[0]) ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False) IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) for tag in self.payload: - if tag_type != tag.TYPE: + if tag_type != _get_tag_type(tag): raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}") if tag.name != "": raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}") tag.write_to(buf, with_type=False, with_name=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT: """Read the ListNBT tag from the buffer. @@ -1009,69 +999,104 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ListNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") list_tag_type = ByteNBT.read_from(buf, with_type=False, with_name=False).payload try: length = IntNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a list.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a list.") from exc - if length < 0 or list_tag_type == NBTagType.END: + if length < 0 or list_tag_type is NBTagType.END: return ListNBT([], name=name) try: list_tag_type = NBTagType(list_tag_type) - except ValueError: - raise TypeError(f"Unknown tag type {list_tag_type}.") from None + except ValueError as exc: + raise TypeError(f"Unknown tag type {list_tag_type}.") from exc - list_type_class = NBTag.ASSOCIATED_TYPES.get(list_tag_type, NBTag) - if list_type_class == NBTag: + list_type_class = ASSOCIATED_TYPES.get(list_tag_type, NBTag) + if list_type_class is NBTag: raise TypeError(f"Unknown tag type {list_tag_type}.") # pragma: no cover try: - payload = [ - # The type is already known, so we don't need to read it again - # List items are unnamed, so we don't need to read the name - list_type_class.read_from(buf, with_type=False, with_name=False) - for _ in range(length) - ] - except IOError: - raise IOError("Buffer does not contain enough data to read the list.") from None + payload = [list_type_class.read_from(buf, with_type=False, with_name=False) for _ in range(length)] + except IOError as exc: + raise IOError("Buffer does not contain enough data to read the list.") from exc return ListNBT(payload, name=name) - def __iter__(self): + def __iter__(self) -> Iterator[NBTag]: """Iterate over the tags in the list.""" yield from self.payload + @override def __repr__(self) -> str: """Get a string representation of the ListNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" - - def to_object(self) -> Mapping[str, list[PayloadType]] | list[PayloadType]: + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + list[PayloadType] + | Mapping[str, list[PayloadType]] + | tuple[list[PayloadType] | Mapping[str, list[PayloadType]], list[FromObjectSchema]] + ): """Convert the ListNBT tag to a python object. - :return: A dictionary containing the name and the list of tags. If the tag has no name, the list will be - returned directly. + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A list containing the payload of the tag. (default) + - A dictionary containing the name associated with a list containing the payload of the tag. + - A tuple which includes one of the above and a list of schemas describing the types of the original tag. """ - self.payload: list[NBTag] - if self.name: - return {self.name: [tag.to_object() for tag in self.payload]} # Extract the (unnamed) object from each tag - return [tag.to_object() for tag in self.payload] # Extract the (unnamed) object from each tag + result = [tag.to_object() for tag in self.payload] + result = cast(List[PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = [ + cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + ] + if len(result) == 0: + return result, [] + + first = subschemas[0] + if all(schema == first for schema in subschemas): + return result, [first] + + if not isinstance(first, (dict, list)): + raise TypeError(f"The schema must contain either a dict or a list. Found {first!r}") + # This will take care of ensuring either everything is a dict or a list + if not all(isinstance(schema, type(first)) for schema in subschemas): + raise TypeError(f"All items in the list must have the same type. Found {subschemas!r}") + return result, subschemas + return result + + @property + @override + def value(self) -> list[PayloadType]: + """Get the payload of the ListNBT tag in a python-friendly format.""" + return [tag.value for tag in self.payload] class CompoundNBT(NBTag): """NBT tag representing a compound of named tags.""" - TYPE = NBTagType.COMPOUND - __slots__ = () payload: list[NBTag] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the CompoundNBT tag to the buffer. @@ -1102,6 +1127,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) tag.write_to(buf) EndNBT().write_to(buf, with_name=False, with_type=True) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT: """Read the CompoundNBT tag from the buffer. @@ -1114,16 +1140,16 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The CompoundNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - payload = [] + payload: list[NBTag] = [] while True: child_name, child_type = cls._read_header(buf, with_name=True, read_type=True) - if child_type == NBTagType.END: + if child_type is NBTagType.END: break # The name and type of the tag have already been read - tag = NBTag.ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) + tag = ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) tag.name = child_name payload.append(tag) return CompoundNBT(payload, name=name) @@ -1133,29 +1159,50 @@ def __iter__(self): for tag in self.payload: yield tag.name, tag + @override def __repr__(self) -> str: """Get a string representation of the CompoundNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}]({dict(self)})" - return f"{self.__class__.__name__}({dict(self)})" - - def to_object(self) -> Mapping[str, Mapping[str, PayloadType]]: + return f"{type(self).__name__}[{self.name!r}]({dict(self)})" + return f"{type(self).__name__}({dict(self)})" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + Mapping[str, PayloadType] + | Mapping[str, Mapping[str, PayloadType]] + | tuple[ + Mapping[str, PayloadType] | Mapping[str, Mapping[str, PayloadType]], + Mapping[str, FromObjectSchema], + ] + ): """Convert the CompoundNBT tag to a python object. - :return: A dictionary containing the name and the dictionary of tags. If the tag has no name, the dictionary - will be returned directly. + :param include_schema: Whether to return a schema describing the types of the original tag and its children. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A dictionary containing the payload of the tag. (default) + - A dictionary containing the name associated with a dictionary containing the payload of the tag. + - A tuple which includes one of the above and a dictionary of schemas describing the types of the original tag. """ - result = {} - for tag in self.payload: - if tag.name in result: - raise ValueError(f"Duplicate tag name {tag.name!r} in the compound.") - if tag.name == "": - raise ValueError("All tags in a compound must have a name.") - result.update(cast("dict[str, PayloadType]", tag.to_object())) - if self.name: - return {self.name: result} + result = {tag.name: tag.to_object() for tag in self.payload} + result = cast(Mapping[str, PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = { + tag.name: cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + } + return result, subschemas return result + @override def __eq__(self, other: object) -> bool: """Check equality between two CompoundNBT tags. @@ -1169,24 +1216,30 @@ def __eq__(self, other: object) -> bool: # The order of the tags is not guaranteed if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") - if self.name != other.name or self.TYPE != other.TYPE: + if type(self) is not type(other): + return False + if self.name != other.name: return False - if not isinstance(other, self.__class__): # pragma: no cover - return False # Should not happen if nobody messes with the TYPE attribute + other = cast(CompoundNBT, other) if len(self.payload) != len(other.payload): return False return all(tag in other.payload for tag in self.payload) + @property + @override + def value(self) -> dict[str, PayloadType]: + """Get the dictionary of tags in the CompoundNBT tag.""" + return {tag.name: tag.value for tag in self.payload} + class IntArrayNBT(NBTag): """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.INT_ARRAY - __slots__ = () payload: list[int] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the IntArrayNBT tag to the buffer. @@ -1208,6 +1261,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) for i in self.payload: IntNBT(i).write_to(buf, with_name=False, with_type=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: """Read the IntArrayNBT tag from the buffer. @@ -1224,36 +1278,28 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") length = IntNBT.read_from(buf, with_type=False, with_name=False).value try: - payload = [IntNBT.read_from(buf, with_type == NBTagType.INT, with_name=False).value for _ in range(length)] - except IOError: + payload = [IntNBT.read_from(buf, with_type is NBTagType.INT, with_name=False).value for _ in range(length)] + except IOError as exc: raise IOError( "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" - ) from None + ) from exc return IntArrayNBT(payload, name=name) + @override def __repr__(self) -> str: """Get a string representation of the IntArrayNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" - def __iter__(self): + def __iter__(self) -> Iterator[int]: """Iterate over the integers in the array.""" yield from self.payload - def to_object(self) -> Mapping[str, list[int]] | list[int]: - """Convert the IntArrayNBT tag to a python object. - - :return: A dictionary containing the name and the list of integers. If the tag has no name, the list will be - returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> list[int]: """Get the list of integers in the IntArrayNBT tag.""" return self.payload @@ -1262,10 +1308,9 @@ def value(self) -> list[int]: class LongArrayNBT(IntArrayNBT): """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.LONG_ARRAY - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the LongArrayNBT tag to the buffer. @@ -1287,6 +1332,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) for i in self.payload: LongNBT(i).write_to(buf, with_name=False, with_type=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: """Read the LongArrayNBT tag from the buffer. @@ -1305,11 +1351,49 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) try: payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] - except IOError: + except IOError as exc: raise IOError( "Buffer does not contain enough data to read the entire long array. (Incomplete data)" - ) from None + ) from exc return LongArrayNBT(payload, name=name) # endregion + +# region: NBT Associated Types +ASSOCIATED_TYPES: dict[NBTagType, type[NBTag]] = { + NBTagType.END: EndNBT, + NBTagType.BYTE: ByteNBT, + NBTagType.SHORT: ShortNBT, + NBTagType.INT: IntNBT, + NBTagType.LONG: LongNBT, + NBTagType.FLOAT: FloatNBT, + NBTagType.DOUBLE: DoubleNBT, + NBTagType.BYTE_ARRAY: ByteArrayNBT, + NBTagType.STRING: StringNBT, + NBTagType.LIST: ListNBT, + NBTagType.COMPOUND: CompoundNBT, + NBTagType.INT_ARRAY: IntArrayNBT, + NBTagType.LONG_ARRAY: LongArrayNBT, +} + + +def _get_tag_type(tag: NBTag | type[NBTag]) -> NBTagType: + """Get the tag type of an NBTag object or class. + + :param tag: The tag to get the type of. + + :return: The tag type of the tag. + """ + cls = tag if isinstance(tag, type) else type(tag) + + if cls is NBTag: + return NBTagType.COMPOUND + for tag_type, tag_cls in ASSOCIATED_TYPES.items(): + if cls is tag_cls: + return tag_type + + raise ValueError(f"Unknown tag type {cls}.") # pragma: no cover + + +# endregion diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index 18e5b37c..f68f6e16 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -1,6 +1,8 @@ from __future__ import annotations import struct +from typing import Any, Dict, List, cast +from typing_extensions import override import pytest @@ -22,6 +24,7 @@ PayloadType, ShortNBT, StringNBT, + NBTagConvertible, ) # region EndNBT @@ -41,7 +44,7 @@ def test_serialize_deserialize_end(): assert buffer == bytearray.fromhex("00") buffer = Buffer(bytearray.fromhex("00")) - assert NBTag.deserialize(buffer).TYPE == NBTagType.END + assert EndNBT.deserialize(buffer) == EndNBT() # endregion @@ -86,9 +89,21 @@ def test_serialize_deserialize_end(): (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), - (ByteArrayNBT, b"\x00\x01\x02", bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02"), - (ByteArrayNBT, b"\x00\x01\x02\x03", bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03"), - (ByteArrayNBT, b"\xFF" * 1024, bytearray.fromhex("07 00 00 04 00") + b"\xFF" * 1024), + ( + ByteArrayNBT, + b"\x00\x01\x02", + bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02", + ), + ( + ByteArrayNBT, + b"\x00\x01\x02\x03", + bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03", + ), + ( + ByteArrayNBT, + b"\xff" * 1024, + bytearray.fromhex("07 00 00 04 00") + b"\xff" * 1024, + ), ( ByteArrayNBT, bytes((n - 1) * n * 2 % 256 for n in range(256)), @@ -100,7 +115,11 @@ def test_serialize_deserialize_end(): (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), - (ListNBT, [ShortNBT(127), ShortNBT(256)], bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00")), + ( + ListNBT, + [ShortNBT(127), ShortNBT(256)], + bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00"), + ), ( ListNBT, [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], @@ -124,7 +143,10 @@ def test_serialize_deserialize_end(): ), ( CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), IntNBT(-1, "Int 2")], + [ + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), + IntNBT(-1, "Int 2"), + ], bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() + IntNBT(-1, "Int 2").serialize() @@ -132,15 +154,43 @@ def test_serialize_deserialize_end(): ), (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), - (IntArrayNBT, [0, 1], bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01")), - (IntArrayNBT, [1, 2, 3], bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03")), + ( + IntArrayNBT, + [0, 1], + bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01"), + ), + ( + IntArrayNBT, + [1, 2, 3], + bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03"), + ), (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), - (IntArrayNBT, [(1 << 31) - 1, (1 << 31) - 2], bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE")), - (IntArrayNBT, [-1, -2, -3], bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD")), - (IntArrayNBT, [12] * 1024, bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0C" * 1024), + ( + IntArrayNBT, + [(1 << 31) - 1, (1 << 31) - 2], + bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE"), + ), + ( + IntArrayNBT, + [-1, -2, -3], + bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD"), + ), + ( + IntArrayNBT, + [12] * 1024, + bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0c" * 1024, + ), (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), - (LongArrayNBT, [0], bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00")), - (LongArrayNBT, [0, 1], bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01")), + ( + LongArrayNBT, + [0], + bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00"), + ), + ( + LongArrayNBT, + [0, 1], + bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01"), + ), ( LongArrayNBT, [1, 2, 3], @@ -148,7 +198,11 @@ def test_serialize_deserialize_end(): "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" ), ), - (LongArrayNBT, [(1 << 63) - 1], bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF")), + ( + LongArrayNBT, + [(1 << 63) - 1], + bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF"), + ), ( LongArrayNBT, [(1 << 63) - 1, (1 << 63) - 2], @@ -161,7 +215,11 @@ def test_serialize_deserialize_end(): "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" ), ), - (LongArrayNBT, [12] * 1024, bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0C" * 1024), + ( + LongArrayNBT, + [12] * 1024, + bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0c" * 1024, + ), ], ) def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): @@ -193,33 +251,108 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType @pytest.mark.parametrize( ("nbt_class", "value", "name", "expected_bytes"), [ - (ByteNBT, 0, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00")), - (ByteNBT, 1, "a", bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01")), - (ByteNBT, 127, "&à@é", bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F")), - (ByteNBT, -128, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80")), - (ByteNBT, 12, "a" * 100, bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C")), - (ShortNBT, 0, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00")), - (ShortNBT, 1, "a", bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01")), - (ShortNBT, 32767, "&à@é", bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF")), - (ShortNBT, -32768, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00")), - (ShortNBT, 12, "a" * 100, bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C")), - (IntNBT, 0, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), - (IntNBT, 1, "a", bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01")), + ( + ByteNBT, + 0, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00"), + ), + ( + ByteNBT, + 1, + "a", + bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01"), + ), + ( + ByteNBT, + 127, + "&à@é", + bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F"), + ), + ( + ByteNBT, + -128, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80"), + ), + ( + ByteNBT, + 12, + "a" * 100, + bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C"), + ), + ( + ShortNBT, + 0, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + ShortNBT, + 1, + "a", + bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01"), + ), + ( + ShortNBT, + 32767, + "&à@é", + bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF"), + ), + ( + ShortNBT, + -32768, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00"), + ), + ( + ShortNBT, + 12, + "a" * 100, + bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C"), + ), + ( + IntNBT, + 0, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + IntNBT, + 1, + "a", + bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01"), + ), ( IntNBT, 2147483647, "&à@é", bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), ), - (IntNBT, -2147483648, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00")), + ( + IntNBT, + -2147483648, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00"), + ), ( IntNBT, 12, "a" * 100, bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), ), - (LongNBT, 0, "test", bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00")), - (LongNBT, 1, "a", bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01")), + ( + LongNBT, + 0, + "test", + bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00"), + ), + ( + LongNBT, + 1, + "a", + bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01"), + ), ( LongNBT, (1 << 63) - 1, @@ -238,25 +371,60 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "a" * 100, bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), ), - (FloatNBT, 1.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0))), - (FloatNBT, 3.14, "a", bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14))), + ( + FloatNBT, + 1.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0)), + ), + ( + FloatNBT, + 3.14, + "a", + bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14)), + ), ( FloatNBT, -1.0, "&à@é", bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), ), - (FloatNBT, 12.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0))), - (DoubleNBT, 1.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0))), - (DoubleNBT, 3.14, "a", bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14))), + ( + FloatNBT, + 12.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0)), + ), + ( + DoubleNBT, + 1.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0)), + ), + ( + DoubleNBT, + 3.14, + "a", + bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14)), + ), ( DoubleNBT, -1.0, "&à@é", bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), ), - (DoubleNBT, 12.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0))), - (ByteArrayNBT, b"", "test", bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + DoubleNBT, + 12.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0)), + ), + ( + ByteArrayNBT, + b"", + "test", + bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), ( ByteArrayNBT, b"\x00", @@ -277,12 +445,22 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType ), ( ByteArrayNBT, - b"\xFF" * 1024, + b"\xff" * 1024, "a" * 100, - bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xFF" * 1024, + bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xff" * 1024, + ), + ( + StringNBT, + "", + "test", + bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + StringNBT, + "test", + "a", + bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test", ), - (StringNBT, "", "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00")), - (StringNBT, "test", "a", bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test"), ( StringNBT, "a" * 100, @@ -295,7 +473,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), ), - (ListNBT, [], "test", bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00")), + ( + ListNBT, + [], + "test", + bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00"), + ), ( ListNBT, [ByteNBT(-1)], @@ -316,7 +499,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x01a" + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), ), - (CompoundNBT, [], "test", bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00")), + ( + CompoundNBT, + [], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00"), + ), ( CompoundNBT, [ByteNBT(0, name="Byte")], @@ -348,7 +536,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "test", bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", ), - (IntArrayNBT, [], "test", bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + IntArrayNBT, + [], + "test", + bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), ( IntArrayNBT, [0], @@ -381,9 +574,14 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 01") - + b"\x7F\xFF\xFF\xFF", + + b"\x7f\xff\xff\xff", + ), + ( + LongArrayNBT, + [], + "test", + bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), ), - (LongArrayNBT, [], "test", bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), ( LongArrayNBT, [0], @@ -419,7 +617,7 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 64") - + b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF" * 100, + + b"\x7f\xff\xff\xff\xff\xff\xff\xff" * 100, ), ], ) @@ -471,9 +669,6 @@ def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, with pytest.raises(OverflowError): nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) - with pytest.raises(ValueError): # No name - nbt_class(0, "").serialize() # without with_name=False - # Deserialization buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher @@ -495,9 +690,6 @@ def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, def test_serialize_deserialize_float_fail(): """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" - with pytest.raises(ValueError): - FloatNBT(0, 0).serialize() # type:ignore - with pytest.raises(struct.error): FloatNBT("test").serialize(with_name=False) @@ -524,9 +716,6 @@ def test_serialize_deserialize_float_fail(): def test_serialize_deserialize_double_fail(): """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" - with pytest.raises(ValueError): - DoubleNBT(0, 0).serialize() # type: ignore - with pytest.raises(struct.error): DoubleNBT("test").serialize(with_name=False) @@ -547,12 +736,6 @@ def test_serialize_deserialize_double_fail(): def test_serialize_deserialize_bytearray_fail(): """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" - with pytest.raises(ValueError): - ByteArrayNBT([], 0).serialize() # type:ignore - - with pytest.raises(ValueError): - ByteArrayNBT(b"test", "").serialize() - # Deserialization buffer = Buffer(bytearray([0x01] + [0] * 4)) with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT @@ -585,12 +768,6 @@ def test_serialize_deserialize_bytearray_fail(): def test_serialize_deserialize_string_fail(): """Test serialization/deserialization of NBT STRING tag with invalid value.""" - with pytest.raises(ValueError): - StringNBT("", 0).serialize() # type:ignore - - with pytest.raises(ValueError): - StringNBT("test", "").serialize() - # Deserialization buffer = Buffer(bytearray([0x01, 0, 0])) with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT @@ -636,7 +813,7 @@ def test_serialize_deserialize_string_fail(): ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation ], ) -def test_serialize_list_fail(payload, error): +def test_serialize_list_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT LIST tag with invalid value.""" with pytest.raises(error): ListNBT(payload, "test").serialize() @@ -671,10 +848,13 @@ def test_deserialize_list_fail(): ([ByteNBT(0, name="hi"), "test"], ValueError), ([ByteNBT(0, name="hi"), None], ValueError), ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ([ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], OverflowError), # Check for error propagation + ( + [ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], + OverflowError, + ), # Check for error propagation ], ) -def test_serialize_compound_fail(payload, error): +def test_serialize_compound_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT COMPOUND tag with invalid value.""" with pytest.raises(error): CompoundNBT(payload, "test").serialize() @@ -702,30 +882,40 @@ def test_deseialize_compound_fail(): NBTag.deserialize(buffer) -def test_to_object_compound(): - """Try a few incorrect CompoundNBT.to_object() calls.""" - comp = CompoundNBT([ByteNBT(0, "test"), ByteNBT(1, "test")]) - with pytest.raises(ValueError): - comp.to_object() # Duplicate name +def test_nbtag_deserialize_compound(): + """Test deserialization of NBT COMPOUND tag from the NBTag class.""" + buf = Buffer(bytearray([0x00])) + assert NBTag.deserialize(buf, with_type=False, with_name=False) == CompoundNBT([]) - comp = CompoundNBT([ByteNBT(0), ByteNBT(1)]) - with pytest.raises(ValueError): - comp.to_object() + buf = Buffer(bytearray.fromhex("0A 00 01 61 01 00 01 62 00 00")) + assert NBTag.deserialize(buf) == CompoundNBT([ByteNBT(0, name="b")], name="a") def test_equality_compound(): """Test equality of CompoundNBT.""" - comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp1 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) assert comp1 == comp2 comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], + "comp", + ) assert comp1 != comp2 - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp2", + ) assert comp1 != comp2 assert comp1 != ByteNBT(0, name="comp") @@ -744,7 +934,7 @@ def test_equality_compound(): ([0, -(1 << 31) - 1], OverflowError), ], ) -def test_serialize_intarray_fail(payload, error): +def test_serialize_intarray_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT INTARRAY tag with invalid value.""" with pytest.raises(error): IntArrayNBT(payload, "test").serialize() @@ -781,7 +971,7 @@ def test_deserialize_intarray_fail(): ([0, -(1 << 63) - 1], OverflowError), ], ) -def test_serialize_deserialize_longarray_fail(payload, error): +def test_serialize_deserialize_longarray_fail(payload: PayloadType, error: type[Exception]): """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" with pytest.raises(error): LongArrayNBT(payload, "test").serialize() @@ -819,60 +1009,86 @@ def test_nbt_helloworld(): buffer = Buffer(data) expected_object = { - "hello world": { - "name": "Bananrama", - } + "name": "Bananrama", } + expected_schema = {"name": StringNBT} data = CompoundNBT.deserialize(buffer) - assert data == NBTag.from_object(expected_object) + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="hello world") assert data.to_object() == expected_object def test_nbt_bigfile(): """Test serialization/deserialization of a big NBT tag. - Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. Source data: https://wiki.vg/NBT#Example. """ data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc7b5e00" # noqa: E501 - data = bytearray.fromhex(data) + data = bytes.fromhex(data) buffer = Buffer(data) - expected_object = { - "Level": { - "longTest": 9223372036854775807, - "shortTest": 32767, - "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", - "floatTest": 0.4982314705848694, - "intTest": 2147483647, - "nested compound test": { - "ham": {"name": "Hampus", "value": 0.75}, - "egg": {"name": "Eggbert", "value": 0.5}, + expected_object = { # Name ! Level + "longTest": 9223372036854775807, + "shortTest": 32767, + "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", + "floatTest": 0.4982314705848694, + "intTest": 2147483647, + "nested compound test": { + "ham": {"name": "Hampus", "value": 0.75}, + "egg": {"name": "Eggbert", "value": 0.5}, + }, + "listTest (long)": [11, 12, 13, 14, 9223372036854775807], + "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], + "listTest (compound)": [ + {"name": "Compound tag #0", "created-on": 1264099775885}, + {"name": "Compound tag #1", "created-on": 1264099775885}, + ], + "byteTest": 127, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": bytes((n * n * 255 + n * 7) % 100 for n in range(1000)), + "doubleTest": 0.4931287132182315, + } + expected_schema = { + "longTest": LongNBT, + "shortTest": ShortNBT, + "stringTest": StringNBT, + "floatTest": FloatNBT, + "intTest": IntNBT, + "nested compound test": { + "ham": { + "name": StringNBT, + "value": FloatNBT, }, - "listTest (long)": [11, 12, 13, 14, 9223372036854775807], - "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], - "listTest (compound)": [ - {"name": "Compound tag #0", "created-on": 1264099775885}, - {"name": "Compound tag #1", "created-on": 1264099775885}, - ], - "byteTest": 127, - "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100" - ", starting with n=0 (0, 62, 34, 16, 8, ...))": bytearray( - (n * n * 255 + n * 7) % 100 for n in range(1000) - ), - "doubleTest": 0.4931287132182315, - } + "egg": { + "name": StringNBT, + "value": FloatNBT, + }, + }, + "listTest (long)": LongArrayNBT, + "listTest (int)": IntArrayNBT, + "listTest (compound)": [ + { + "name": StringNBT, + "created-on": LongNBT, + } + ], + "byteTest": ByteNBT, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": ByteArrayNBT, + "doubleTest": FloatNBT, } data = CompoundNBT.deserialize(buffer) # print(f"{data=}\n{expected_object=}\n{data.to_object()=}\n{NBTag.from_object(expected_object)=}") - def check_equality(self, other): + def check_equality(self: object, other: object) -> bool: """Check if two objects are equal, with deep epsilon check for floats.""" if type(self) != type(other): return False if isinstance(self, dict): + self = cast(Dict[Any, Any], self) + other = cast(Dict[Any, Any], other) if len(self) != len(other): return False for key in self: @@ -882,16 +1098,18 @@ def check_equality(self, other): return False return True if isinstance(self, list): + self = cast(List[Any], self) + other = cast(List[Any], other) if len(self) != len(other): return False return all(check_equality(self[i], other[i]) for i in range(len(self))) - if isinstance(self, float): + if isinstance(self, float) and isinstance(other, float): return abs(self - other) < 1e-6 if self != other: return False return self == other - assert data == NBTag.from_object(expected_object) + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="Level") assert check_equality(data.to_object(), expected_object) @@ -902,30 +1120,16 @@ def check_equality(self, other): def test_from_object_lst_not_same_type(): """Test from_object with a list that does not have the same type.""" with pytest.raises(TypeError): - NBTag.from_object([ByteNBT(0), IntNBT(0)]) - - -def test_from_object_out_of_bounds(): - """Test from_object with a value that is out of bounds.""" - with pytest.raises(ValueError): - NBTag.from_object({"test": 1 << 63}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": -(1 << 63) - 1}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": [1 << 63]}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": [-(1 << 63) - 1]}) + NBTag.from_object([0, "test"], [IntNBT, StringNBT]) def test_from_object_morecases(): """Test from_object with more edge cases.""" - class CustomType: - def __bytes__(self): - return b"test" + class CustomType(NBTagConvertible): + @override + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) assert NBTag.from_object( { @@ -933,60 +1137,48 @@ def __bytes__(self): "bytearray": b"test", # Conversion from bytes "empty_list": [], # Empty list with type EndNBT "empty_compound": {}, # Empty compound - "end_NBTag": None, # Should not be done in practice, would create a broken buffer if serialized - "custom": CustomType(), # Custom type with __bytes__ method - } + "custom": CustomType(), # Custom type with to_nbt method + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + }, + { + "nbtag": ByteNBT, + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + }, ) == CompoundNBT( [ # Order is shuffled because the spec does not require a specific order CompoundNBT([], "empty_compound"), ByteArrayNBT(b"test", "bytearray"), - ByteArrayNBT(b"test", "custom"), + ByteArrayNBT(b"CustomType", "custom"), ListNBT([], "empty_list"), ByteNBT(0, "nbtag"), - EndNBT(), + ListNBT( + [ListNBT([IntNBT(0), IntNBT(1), IntNBT(2)]), ListNBT([ShortNBT(3), ShortNBT(4), ShortNBT(5)])], + "recursive_list", + ), ] ) - # Not a valid object - with pytest.raises(TypeError): - NBTag.from_object({"test": object()}) - compound = CompoundNBT.from_object( { "test": ByteNBT(0), "test2": IntNBT(0), }, + { + "test": ByteNBT, + "test2": IntNBT, + }, name="compound", ) - assert compound["test"] == ByteNBT(0, "test") - assert compound["test2"] == IntNBT(0, "test2") - with pytest.raises(KeyError): - compound["test3"] - # Cannot index into a ByteNBT - with pytest.raises(TypeError): - compound["test"][0] # type:ignore - - listnbt = ListNBT.from_object([0, 1, 2], use_int_array=False) - assert listnbt[0] == ByteNBT(0) - assert listnbt[1] == ByteNBT(1) - assert listnbt[2] == ByteNBT(2) - with pytest.raises(IndexError): - listnbt[3] - with pytest.raises(TypeError): - listnbt["hello"] - - assert listnbt[-1] == ByteNBT(2) - assert listnbt[-2] == ByteNBT(1) - assert listnbt[-3] == ByteNBT(0) - - with pytest.raises(TypeError): - listnbt[object()] # type:ignore - - assert listnbt.value == [0, 1, 2] - assert listnbt.to_object() == [0, 1, 2] assert ListNBT([]).value == [] - assert compound.to_object() == {"compound": {"test": 0, "test2": 0}} + assert compound.to_object(include_name=True) == {"compound": {"test": 0, "test2": 0}} assert compound.value == {"test": 0, "test2": 0} assert ListNBT([IntNBT(0)]).value == [0] @@ -1001,21 +1193,86 @@ def __bytes__(self): assert IntArrayNBT([0, 1, 2]).value == [0, 1, 2] assert LongArrayNBT([0, 1, 2, 3]).value == [0, 1, 2, 3] - invalid = ListNBT("Hello", "name") - with pytest.raises(AttributeError): - invalid[0] - invalid = CompoundNBT([ByteNBT(0, "Byte"), "Hi"], "name") - with pytest.raises(AttributeError): - invalid["Byte"] # Attribute error is raised when the structure is incorrectly constructed +@pytest.mark.parametrize( + ("data", "schema", "error", "error_msg"), + [ + # Data is not a list + ({"test": 0}, {"test": [ByteNBT]}, TypeError, "Expected a list, but found a different type."), + # Expected a list of dict, got a list of NBTags for schema + ( + {"test": [1, 0]}, + {"test": [ByteNBT, IntNBT]}, + TypeError, + "Expected a list of lists or dictionaries, but found a different type.", + ), + # Schema and data have different lengths + ( + [[1], [2], [3]], + [[ByteNBT], [IntNBT]], + ValueError, + "The schema and the data must have the same length.", + ), + # schema empty, data is not + ([1], [], ValueError, "The schema is empty, but the data is not."), + # Schema is a dict, data is not + (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), + # Schema is not a dict, list or subclass of NBTagConvertible + ( + ["test"], + "test", + TypeError, + "The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.", + ), + # Schema contains CompoundNBT or ListNBT instead of a dict or list + ( + {"test": 0}, + CompoundNBT, + ValueError, + "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + ), + ( + ["test"], + ListNBT, + ValueError, + "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + ), + # The schema specifies a type, but the data is a dict with more than one key + ( + {"test": 0, "test2": 1}, + ByteNBT, + ValueError, + "Expected a dictionary with a single key-value pair.", + ), + # The data is not of the right type to be a payload + ( + {"test": object()}, + ByteNBT, + TypeError, + "Expected a bytes, str, int, float, but found object.", + ), + # The data is a list but not all elements are ints + ( + [0, "test"], + IntArrayNBT, + TypeError, + "Expected a list of integers.", + ), + ], +) +def test_from_object_error(data: Any, schema: Any, error: type[Exception], error_msg: str): + """Test from_object with erroneous data.""" + with pytest.raises(error, match=error_msg): + NBTag.from_object(data, schema) def test_to_object_morecases(): """Test to_object with more edge cases.""" - class CustomType: - def __bytes__(self): - return b"test" + class CustomType(NBTagConvertible): + @override + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) assert NBTag.from_object( { @@ -1023,27 +1280,58 @@ def __bytes__(self): "empty_list": [], "empty_compound": {}, "custom": CustomType(), - } - ).to_object() == { - "bytearray": b"test", - "empty_list": [], - "empty_compound": {}, - "custom": b"test", - } - - assert NBTag.to_object(CompoundNBT([])) == {} + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ).to_object(include_schema=True) == ( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": b"CustomType", + "recursive_list": [[0, 1, 2], [3, 4, 5]], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": ByteArrayNBT, # After the conversion, the NBT tag is a ByteArrayNBT + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ) - assert EndNBT().to_object() == {} # Does not add anything when doing dict.update assert FloatNBT(0.5).to_object() == 0.5 - assert FloatNBT(0.5, "Hello World").to_object() == {"Hello World": 0.5} + assert FloatNBT(0.5, "Hello World").to_object(include_name=True) == {"Hello World": 0.5} assert ByteArrayNBT(b"test").to_object() == b"test" # Do not add name when there is no name assert StringNBT("test").to_object() == "test" - assert StringNBT("test", "name").to_object() == {"name": "test"} + assert StringNBT("test", "name").to_object(include_name=True) == {"name": "test"} assert ListNBT([ByteNBT(0), ByteNBT(1)]).to_object() == [0, 1] - assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object() == {"name": [0, 1]} + assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object(include_name=True) == {"name": [0, 1]} assert IntArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] assert LongArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + with pytest.raises(TypeError): + NBTag.to_object(CompoundNBT([])) + + with pytest.raises(TypeError): + ListNBT([CompoundNBT([]), ListNBT([])]).to_object(include_schema=True) + + with pytest.raises(TypeError): + ListNBT([IntNBT(0), ShortNBT(0)]).to_object(include_schema=True) + def test_data_conversions(): """Test data conversions using the built-in functions.""" @@ -1063,11 +1351,32 @@ def test_data_conversions(): def test_init_nbtag_directly(): """Test initializing NBTag directly.""" with pytest.raises(TypeError): - NBTag(0) - with pytest.raises(TypeError): - NBTag(0, "test") + NBTag(0) # type: ignore # I know, that's what I'm testing + + +@pytest.mark.parametrize( + ("buffer_content", "tag_type"), + [ + ("01", EndNBT), + ("00 00", ByteNBT), + ("01 0000", ShortNBT), + ("02 00000000", IntNBT), + ("03 0000000000000000", LongNBT), + ("04 3F800000", FloatNBT), + ("05 3FF999999999999A", DoubleNBT), + ("06 00", ByteArrayNBT), + ("07 00", StringNBT), + ("08 00", ListNBT), + ("09 00", CompoundNBT), + ("0A 00", IntArrayNBT), + ("0B 00", LongArrayNBT), + ], +) +def test_wrong_type(buffer_content: str, tag_type: type[NBTag]): + """Test read_from with wrong tag type in the buffer.""" + buffer = Buffer(bytearray.fromhex(buffer_content)) with pytest.raises(TypeError): - NBTag(0, name="test") + tag_type.read_from(buffer, with_name=False) # endregion