From 3c2f0f1e88c62bb4a00703db09f2d13da3372c27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 18 Feb 2022 18:04:38 +0100 Subject: [PATCH 1/4] Config: do not sort positional arguments In a configuration such as the following: [model] [model.chain] @layers = "chain.v1" [model.chain.*.hashembed] @layers = "HashEmbed.v1" nO = 8 nV = 8 [model.chain.*.expand_window] @layers = "expand_window.v1" window_size = 1 The positional arguments in the chain were sorted, changing the chain order. This change masks the names of positional arguments. Since `sorted` is a stable sort, the order of positional arguments is retained. --- thinc/config.py | 19 ++++++++++++++++++- thinc/tests/test_config.py | 26 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/thinc/config.py b/thinc/config.py index c779f9e4d..54c82422d 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -340,7 +340,10 @@ def _sort( account for subsections, which should always follow their parent. """ sort_map = {section: i for i, section in enumerate(self.section_order)} - sort_key = lambda x: (sort_map.get(x[0].split(".")[0], len(sort_map)), x[0]) + sort_key = lambda x: ( + sort_map.get(x[0].split(".")[0], len(sort_map)), + mask_positional_args(x[0]), + ) return dict(sorted(data.items(), key=sort_key)) def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None: @@ -456,6 +459,20 @@ def from_disk( return self.from_str(text, interpolate=interpolate, overrides=overrides) +def mask_positional_args(name: str) -> List[str]: + """Create a section name representation that masks names + of positional arguments to retain their order in sorts.""" + + stable_name = name.split(".") + + # Remove names of sections that are a positional arugment. + for i in range(1, len(stable_name)): + if stable_name[i - 1] == "*": + stable_name[i] = None + + return stable_name + + def try_load_json(value: str) -> Any: """Load a JSON string if possible, otherwise default to original value.""" try: diff --git a/thinc/tests/test_config.py b/thinc/tests/test_config.py index 2fd87b4c7..8d7aeb97e 100644 --- a/thinc/tests/test_config.py +++ b/thinc/tests/test_config.py @@ -1437,3 +1437,29 @@ def test_config_overrides(greeting, value, expected): assert "${vars.a}" in str_cfg cfg = Config().from_str(str_cfg, overrides=overrides) assert expected in str(cfg) + + +def test_arg_order_is_preserved(): + str_cfg = """ + [model] + + [model.chain] + @layers = "chain.v1" + + [model.chain.*.hashembed] + @layers = "HashEmbed.v1" + nO = 8 + nV = 8 + + [model.chain.*.expand_window] + @layers = "expand_window.v1" + window_size = 1 + """ + + cfg = Config().from_str(str_cfg) + resolved = my_registry.resolve(cfg) + model = resolved["model"]["chain"] + + # Fails when arguments are sorted, because expand_window + # is sorted before hashembed. + assert model.name == "hashembed>>expand_window" From 2fa6327c0f44f4a28bbd0ee5026b7873ad93c258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 18 Feb 2022 18:23:03 +0100 Subject: [PATCH 2/4] Config: type fixes --- thinc/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/thinc/config.py b/thinc/config.py index 54c82422d..12d876420 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -1,5 +1,5 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type -from typing import Iterable, Sequence +from typing import Iterable, Sequence, cast from types import GeneratorType from dataclasses import dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH @@ -459,11 +459,11 @@ def from_disk( return self.from_str(text, interpolate=interpolate, overrides=overrides) -def mask_positional_args(name: str) -> List[str]: +def mask_positional_args(name: str) -> List[Optional[str]]: """Create a section name representation that masks names of positional arguments to retain their order in sorts.""" - stable_name = name.split(".") + stable_name = cast(List[Optional[str]], name.split(".")) # Remove names of sections that are a positional arugment. for i in range(1, len(stable_name)): From eb2313a2cd95261a87eab0b550ed79837836734f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 21 Feb 2022 12:17:55 +0100 Subject: [PATCH 3/4] config: Fix a typo Co-authored-by: Sofie Van Landeghem --- thinc/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thinc/config.py b/thinc/config.py index 12d876420..c078ddeeb 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -465,7 +465,7 @@ def mask_positional_args(name: str) -> List[Optional[str]]: stable_name = cast(List[Optional[str]], name.split(".")) - # Remove names of sections that are a positional arugment. + # Remove names of sections that are a positional argument. for i in range(1, len(stable_name)): if stable_name[i - 1] == "*": stable_name[i] = None From 9257d9c5155312683c7cf794ccd7e285f25b0e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 21 Feb 2022 12:17:25 +0100 Subject: [PATCH 4/4] Make mask_positional_args private --- thinc/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thinc/config.py b/thinc/config.py index c078ddeeb..167f8fd97 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -342,7 +342,7 @@ def _sort( sort_map = {section: i for i, section in enumerate(self.section_order)} sort_key = lambda x: ( sort_map.get(x[0].split(".")[0], len(sort_map)), - mask_positional_args(x[0]), + _mask_positional_args(x[0]), ) return dict(sorted(data.items(), key=sort_key)) @@ -459,7 +459,7 @@ def from_disk( return self.from_str(text, interpolate=interpolate, overrides=overrides) -def mask_positional_args(name: str) -> List[Optional[str]]: +def _mask_positional_args(name: str) -> List[Optional[str]]: """Create a section name representation that masks names of positional arguments to retain their order in sorts."""