Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config: do not sort positional arguments #594

Merged
merged 4 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions thinc/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type
from typing import Iterable, Sequence
from typing import Iterable, Sequence, cast
from types import GeneratorType
from dataclasses import dataclass
from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH
Expand Down Expand Up @@ -340,7 +340,10 @@ def _sort(
account for subsections, which should always follow their parent.
"""
sort_map = {section: i for i, section in enumerate(self.section_order)}
sort_key = lambda x: (sort_map.get(x[0].split(".")[0], len(sort_map)), x[0])
sort_key = lambda x: (
sort_map.get(x[0].split(".")[0], len(sort_map)),
_mask_positional_args(x[0]),
)
return dict(sorted(data.items(), key=sort_key))

def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -456,6 +459,20 @@ def from_disk(
return self.from_str(text, interpolate=interpolate, overrides=overrides)


def _mask_positional_args(name: str) -> List[Optional[str]]:
"""Create a section name representation that masks names
of positional arguments to retain their order in sorts."""

stable_name = cast(List[Optional[str]], name.split("."))

# Remove names of sections that are a positional argument.
for i in range(1, len(stable_name)):
if stable_name[i - 1] == "*":
stable_name[i] = None
svlandeg marked this conversation as resolved.
Show resolved Hide resolved

return stable_name


def try_load_json(value: str) -> Any:
"""Load a JSON string if possible, otherwise default to original value."""
try:
Expand Down
26 changes: 26 additions & 0 deletions thinc/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,3 +1437,29 @@ def test_config_overrides(greeting, value, expected):
assert "${vars.a}" in str_cfg
cfg = Config().from_str(str_cfg, overrides=overrides)
assert expected in str(cfg)


def test_arg_order_is_preserved():
str_cfg = """
[model]

[model.chain]
@layers = "chain.v1"

[model.chain.*.hashembed]
@layers = "HashEmbed.v1"
nO = 8
nV = 8

[model.chain.*.expand_window]
@layers = "expand_window.v1"
window_size = 1
"""

cfg = Config().from_str(str_cfg)
resolved = my_registry.resolve(cfg)
model = resolved["model"]["chain"]

# Fails when arguments are sorted, because expand_window
# is sorted before hashembed.
assert model.name == "hashembed>>expand_window"