diff --git a/tests/test_build.py b/tests/test_build.py index 078eba9a..098b2d52 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -79,6 +79,7 @@ def test_build_example(example): doc.add(nl()) doc.add(comment("Products")) + doc.add(nl()) products = aot() doc["products"] = products @@ -128,3 +129,20 @@ def test_append_table_after_multiple_indices(): """ doc = parse(content) doc.append("foobar", {"name": "John"}) + + +def test_top_level_keys_are_put_at_the_root_of_the_document(): + doc = document() + doc.add(comment("Comment")) + doc["foo"] = {"name": "test"} + doc["bar"] = 1 + + expected = """\ +# Comment +bar = 1 + +[foo] +name = "test" +""" + + assert doc.as_string() diff --git a/tests/test_parser.py b/tests/test_parser.py index d8219e12..9f759a2e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,6 @@ import pytest +from tomlkit.exceptions import EmptyTableNameError from tomlkit.exceptions import InternalParserError from tomlkit.items import StringType from tomlkit.parser import Parser @@ -13,3 +14,19 @@ def test_parser_should_raise_an_internal_error_if_parsing_wrong_type_of_string() assert e.value.line == 1 assert e.value.col == 0 + + +def test_parser_should_raise_an_error_for_empty_tables(): + content = """ +[one] + +[] +""" + + parser = Parser(content) + + with pytest.raises(EmptyTableNameError) as e: + parser.parse() + + assert e.value.line == 4 + assert e.value.col == 1 diff --git a/tests/test_toml_document.py b/tests/test_toml_document.py index 5fed0848..f5695f6d 100644 --- a/tests/test_toml_document.py +++ b/tests/test_toml_document.py @@ -12,6 +12,7 @@ import tomlkit from tomlkit import parse +from tomlkit._compat import PY36 from tomlkit._utils import _utc from tomlkit.exceptions import NonExistentKey @@ -618,3 +619,53 @@ def test_string_output_order_is_preserved_for_out_of_order_tables(): """ assert expected == doc.as_string() + + +def test_updating_nested_value_keeps_correct_indent(): + content = """ +[Key1] + [key1.Key2] + Value1 = 10 + Value2 = 30 +""" + + doc = parse(content) + doc["key1"]["Key2"]["Value1"] = 20 + + expected = """ +[Key1] + [key1.Key2] + Value1 = 20 + Value2 = 30 +""" + + assert doc.as_string() == expected + + +@pytest.mark.skipif(not PY36, reason="Dict order is not deterministic on Python < 3.6") +def test_repr(): + content = """ +namespace.key1 = "value1" +namespace.key2 = "value2" + +[tool.poetry.foo] +option = "test" + +[tool.poetry.bar] +option = "test" +inline = {"foo" = "bar", "bar" = "baz"} +""" + + doc = parse(content) + + assert ( + repr(doc) + == "{'namespace': {'key1': 'value1', 'key2': 'value2'}, 'tool': {'poetry': {'foo': {'option': 'test'}, 'bar': {'option': 'test', 'inline': {'foo': 'bar', 'bar': 'baz'}}}}}" + ) + + assert ( + repr(doc["tool"]) + == "{'poetry': {'foo': {'option': 'test'}, 'bar': {'option': 'test', 'inline': {'foo': 'bar', 'bar': 'baz'}}}}" + ) + + assert repr(doc["namespace"]) == "{'key1': 'value1', 'key2': 'value2'}" diff --git a/tomlkit/_compat.py b/tomlkit/_compat.py index 8d3b0ae3..487ed990 100644 --- a/tomlkit/_compat.py +++ b/tomlkit/_compat.py @@ -155,6 +155,11 @@ def _name_from_offset(delta): else: from collections import OrderedDict +try: + from collections.abc import MutableMapping +except ImportError: + from collections import MutableMapping + def decode(string, encodings=None): if not PY2 and not isinstance(string, bytes): diff --git a/tomlkit/container.py b/tomlkit/container.py index 6386e738..4e3ea12a 100644 --- a/tomlkit/container.py +++ b/tomlkit/container.py @@ -10,6 +10,7 @@ from typing import Tuple from typing import Union +from ._compat import MutableMapping from ._compat import decode from ._utils import merge_dicts from .exceptions import KeyAlreadyPresent @@ -29,7 +30,7 @@ _NOT_SET = object() -class Container(dict): +class Container(MutableMapping, dict): """ A container for items within a TOMLDocument. """ @@ -111,8 +112,6 @@ def append(self, key, item): # type: (Union[Key, str, None], Item) -> Container if isinstance(item, AoT) and self._body and not self._parsed: if item and "\n" not in item[0].trivia.indent: item[0].trivia.indent = "\n" + item[0].trivia.indent - else: - self.append(None, Whitespace("\n")) if key is not None and key in self: current_idx = self._map[key] @@ -210,7 +209,7 @@ def append(self, key, item): # type: (Union[Key, str, None], Item) -> Container if key_after is not None: if isinstance(key_after, int): - if key_after + 1 < len(self._body) - 1: + if key_after + 1 < len(self._body): return self._insert_at(key_after + 1, key, item) else: previous_item = self._body[-1][1] @@ -247,7 +246,7 @@ def append(self, key, item): # type: (Union[Key, str, None], Item) -> Container self._table_keys.append(key) if key is not None: - super(Container, self).__setitem__(key.key, item.value) + dict.__setitem__(self, key.key, item.value) return self @@ -265,7 +264,7 @@ def remove(self, key): # type: (Union[Key, str]) -> Container else: self._body[idx] = (None, Null()) - super(Container, self).__delitem__(key.key) + dict.__delitem__(self, key.key) return self @@ -312,7 +311,7 @@ def _insert_after( self._body.insert(idx + 1, (other_key, item)) if key is not None: - super(Container, self).__setitem__(other_key.key, item.value) + dict.__setitem__(self, other_key.key, item.value) return self @@ -354,7 +353,7 @@ def _insert_at( self._body.insert(idx, (key, item)) if key is not None: - super(Container, self).__setitem__(key.key, item.value) + dict.__setitem__(self, key.key, item.value) return self @@ -513,33 +512,6 @@ def _render_simple_item(self, key, item, prefix=None): # Dictionary methods - def keys(self): # type: () -> Generator[str] - return super(Container, self).keys() - - def values(self): # type: () -> Generator[Item] - for k in self.keys(): - yield self[k] - - def items(self): # type: () -> Generator[Item] - for k, v in self.value.items(): - if k is None: - continue - - yield k, v - - def update(self, other): # type: (Dict) -> None - for k, v in other.items(): - self[k] = v - - def get(self, key, default=None): # type: (Any, Optional[Any]) -> Any - if not isinstance(key, Key): - key = Key(key) - - if key not in self: - return default - - return self[key] - def pop(self, key, default=_NOT_SET): try: value = self[key] @@ -556,8 +528,7 @@ def pop(self, key, default=_NOT_SET): def setdefault( self, key, default=None ): # type: (Union[Key, str], Any) -> Union[Item, Container] - if key not in self: - self[key] = default + super(Container, self).setdefault(key, default=default) return self[key] @@ -567,6 +538,12 @@ def __contains__(self, key): # type: (Union[Key, str]) -> bool return key in self._map + def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None + if key is not None and key in self: + self._replace(key, key, value) + else: + self.append(key, value) + def __getitem__(self, key): # type: (Union[Key, str]) -> Union[Item, Container] if not isinstance(key, Key): key = Key(key) @@ -596,6 +573,12 @@ def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None def __delitem__(self, key): # type: (Union[Key, str]) -> None self.remove(key) + def __len__(self): # type: () -> int + return dict.__len__(self) + + def __iter__(self): # type: () -> Iterator[str] + return iter(dict.keys(self)) + def _replace( self, key, new_key, value ): # type: (Union[Key, str], Union[Key, str], Item) -> None @@ -627,7 +610,7 @@ def _replace_at( self._map[new_key] = self._map.pop(k) if new_key != k: - super(Container, self).__delitem__(k) + dict.__delitem__(self, k) if isinstance(self._map[new_key], tuple): self._map[new_key] = self._map[new_key][0] @@ -647,13 +630,13 @@ def _replace_at( self._body[idx] = (new_key, value) - super(Container, self).__setitem__(new_key.key, value.value) + dict.__setitem__(self, new_key.key, value.value) def __str__(self): # type: () -> str return str(self.value) def __repr__(self): # type: () -> str - return super(Container, self).__repr__() + return repr(self.value) def __eq__(self, other): # type: (Dict) -> bool if not isinstance(other, dict): @@ -684,8 +667,8 @@ def copy(self): # type: () -> Container def __copy__(self): # type: () -> Container c = self.__class__(self._parsed) - for k, v in super(Container, self).copy().items(): - super(Container, c).__setitem__(k, v) + for k, v in dict.items(self): + dict.__setitem__(c, k, v) c._body += self.body c._map.update(self._map) @@ -693,7 +676,7 @@ def __copy__(self): # type: () -> Container return c -class OutOfOrderTableProxy(dict): +class OutOfOrderTableProxy(MutableMapping, dict): def __init__(self, container, indices): # type: (Container, Tuple) -> None self._container = container self._internal_container = Container(self._container.parsing) @@ -711,12 +694,12 @@ def __init__(self, container, indices): # type: (Container, Tuple) -> None self._internal_container.append(k, v) self._tables_map[k] = table_idx if k is not None: - super(OutOfOrderTableProxy, self).__setitem__(k.key, v) + dict.__setitem__(self, k.key, v) else: self._internal_container.append(key, item) self._map[key] = i if key is not None: - super(OutOfOrderTableProxy, self).__setitem__(key.key, item) + dict.__setitem__(self, key.key, item) @property def value(self): @@ -742,7 +725,7 @@ def __setitem__(self, key, item): # type: (Union[Key, str], Any) -> None self._container[key] = item if key is not None: - super(OutOfOrderTableProxy, self).__setitem__(key, item) + dict.__setitem__(self, key, item) def __delitem__(self, key): # type: (Union[Key, str]) -> None if key in self._map: @@ -784,6 +767,9 @@ def setdefault( def __contains__(self, key): return key in self._internal_container + def __iter__(self): # type: () -> Iterator[str] + return iter(self._internal_container) + def __str__(self): return str(self._internal_container) diff --git a/tomlkit/exceptions.py b/tomlkit/exceptions.py index 44836363..d0c7ab5a 100644 --- a/tomlkit/exceptions.py +++ b/tomlkit/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + from typing import Optional diff --git a/tomlkit/items.py b/tomlkit/items.py index 184ffe7d..f738e0be 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -16,6 +16,7 @@ from ._compat import PY2 from ._compat import PY38 +from ._compat import MutableMapping from ._compat import decode from ._compat import long from ._compat import unicode @@ -866,7 +867,7 @@ def _getstate(self, protocol=3): return self._value, self._trivia -class Table(Item, dict): +class Table(Item, MutableMapping, dict): """ A table literal. """ @@ -890,7 +891,7 @@ def __init__( for k, v in self._value.body: if k is not None: - super(Table, self).__setitem__(k.key, v) + dict.__setitem__(self, k.key, v) @property def value(self): # type: () -> tomlkit.container.Container @@ -924,7 +925,7 @@ def append(self, key, _item): # type: (Union[Key, str], Any) -> Table key = key.key if key is not None: - super(Table, self).__setitem__(key, _item) + dict.__setitem__(self, key, _item) m = re.match("(?s)^[^ ]*([ ]+).*$", self._trivia.indent) if not m: @@ -951,7 +952,7 @@ def raw_append(self, key, _item): # type: (Union[Key, str], Any) -> Table key = key.key if key is not None: - super(Table, self).__setitem__(key, _item) + dict.__setitem__(self, key, _item) return self @@ -962,7 +963,7 @@ def remove(self, key): # type: (Union[Key, str]) -> Table key = key.key if key is not None: - super(Table, self).__delitem__(key) + dict.__delitem__(self, key) return self @@ -992,42 +993,32 @@ def indent(self, indent): # type: (int) -> Table return self - def keys(self): # type: () -> Generator[str] - for k in self._value.keys(): - yield k - - def values(self): # type: () -> Generator[Item] - for v in self._value.values(): - yield v - - def items(self): # type: () -> Generator[Item] - for k, v in self._value.items(): - yield k, v - - def update(self, other): # type: (Dict) -> None - for k, v in other.items(): - self[k] = v - def get(self, key, default=None): # type: (Any, Optional[Any]) -> Any return self._value.get(key, default) - def __contains__(self, key): # type: (Union[Key, str]) -> bool - return key in self._value + def setdefault( + self, key, default=None + ): # type: (Union[Key, str], Any) -> Union[Item, Container] + super(Table, self).setdefault(key, default=default) + + return self[key] def __getitem__(self, key): # type: (Union[Key, str]) -> Item return self._value[key] def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None + fix_indent = key not in self + if not isinstance(value, Item): value = item(value) self._value[key] = value if key is not None: - super(Table, self).__setitem__(key, value) + dict.__setitem__(self, key, value) m = re.match("(?s)^[^ ]*([ ]+).*$", self._trivia.indent) - if not m: + if not m or not fix_indent: return indent = m.group(1) @@ -1042,11 +1033,14 @@ def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None def __delitem__(self, key): # type: (Union[Key, str]) -> None self.remove(key) - def __repr__(self): - return super(Table, self).__repr__() + def __len__(self): # type: () -> int + return len(self._value) - def __str__(self): - return str(self.value) + def __iter__(self): # type: () -> Iterator[str] + return iter(self._value) + + def __repr__(self): # type: () -> str + return repr(self._value) def _getstate(self, protocol=3): return ( @@ -1059,7 +1053,7 @@ def _getstate(self, protocol=3): ) -class InlineTable(Item, dict): +class InlineTable(Item, MutableMapping, dict): """ An inline table literal. """ @@ -1074,7 +1068,7 @@ def __init__( for k, v in self._value.body: if k is not None: - super(InlineTable, self).__setitem__(k.key, v) + dict.__setitem__(self, k.key, v) @property def discriminant(self): # type: () -> int @@ -1103,7 +1097,7 @@ def append(self, key, _item): # type: (Union[Key, str], Any) -> InlineTable key = key.key if key is not None: - super(InlineTable, self).__setitem__(key, _item) + dict.__setitem__(self, key, _item) return self @@ -1114,7 +1108,7 @@ def remove(self, key): # type: (Union[Key, str]) -> InlineTable key = key.key if key is not None: - super(InlineTable, self).__delitem__(key) + dict.__delitem__(self, key) return self @@ -1150,25 +1144,16 @@ def as_string(self): # type: () -> str return buf - def keys(self): # type: () -> Generator[str] - for k in self._value.keys(): - yield k - - def values(self): # type: () -> Generator[Item] - for v in self._value.values(): - yield v - - def items(self): # type: () -> Generator[Item] - for k, v in self._value.items(): - yield k, v - - def update(self, other): # type: (Dict) -> None - for k, v in other.items(): - self[k] = v - def get(self, key, default=None): # type: (Any, Optional[Any]) -> Any return self._value.get(key, default) + def setdefault( + self, key, default=None + ): # type: (Union[Key, str], Any) -> Union[Item, Container] + super(InlineTable, self).setdefault(key, default=default) + + return self[key] + def __contains__(self, key): # type: (Union[Key, str]) -> bool return key in self._value @@ -1182,7 +1167,8 @@ def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None self._value[key] = value if key is not None: - super(InlineTable, self).__setitem__(key, value) + dict.__setitem__(self, key, value) + if value.trivia.comment: value.trivia.comment = "" @@ -1202,8 +1188,14 @@ def __setitem__(self, key, value): # type: (Union[Key, str], Any) -> None def __delitem__(self, key): # type: (Union[Key, str]) -> None self.remove(key) + def __len__(self): # type: () -> int + return len(self._value) + + def __iter__(self): # type: () -> Iterator[str] + return iter(self._value) + def __repr__(self): - return super(InlineTable, self).__repr__() + return repr(self._value) def _getstate(self, protocol=3): return (self._value, self._trivia) diff --git a/tomlkit/parser.py b/tomlkit/parser.py index 49929954..b702088d 100644 --- a/tomlkit/parser.py +++ b/tomlkit/parser.py @@ -1138,9 +1138,11 @@ def _parse_table( ) if is_aot and i == len(name_parts[1:]) - 1: - table.append(_name, AoT([child], name=table.name, parsed=True)) + table.raw_append( + _name, AoT([child], name=table.name, parsed=True) + ) else: - table.append(_name, child) + table.raw_append(_name, child) table = child values = table.value @@ -1201,6 +1203,7 @@ def _peek_table(self): # type: () -> Tuple[bool, str] as well as whether it is part of an AoT. """ # we always want to restore after exiting this scope + table_name = "" with self._state(save_marker=True, restore=True): if self._current != "[": raise self.parse_error(