-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add onmt_config converter to facilitate switch (#69)
- Loading branch information
1 parent
d3f05fe
commit 316bbc2
Showing
3 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# How to switch from OpenNMT-py to EOLE? | ||
|
||
## Configuration conversion | ||
|
||
One of the main pillars of EOLE is the full revamping of the configuration structure and validation logic. That means OpenNMT-py configuration files are not supported by default. | ||
That being said, a conversion tool has been created to facilitate the transition: [`eole convert onmt_config`](https://github.com/eole-nlp/eole/blob/master/eole/bin/convert/convert_onmt_config.py) | ||
|
||
There are a few key things to know: | ||
- what was previous fully "flat" in OpenNMT-py configurations is now mostly nested in nested sections with specific scope such as `training`, `model`, `transforms_configs`; | ||
- some parameters were renamed, removed, or replaced by other logics, which makes the conversion script not 100% exhaustive; | ||
- the conversion script will log the remaining "unmapped settings", to facilitate fixing the last issues manually. | ||
|
||
## Model conversion | ||
|
||
Models trained with OpenNMT-py can technically be converted to be used with EOLE, but there is no automated tool for now. Feel free to get in touch via [Issues](https://github.com/eole-nlp/eole/issues) or [Discussions](https://github.com/eole-nlp/eole/discussions) if that is a blocker. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from eole.bin import BaseBin, register_bin | ||
import yaml | ||
from eole.config.run import TrainConfig | ||
from eole.config.training import TrainingConfig | ||
from eole.config.run import NestedAllTransformsConfig | ||
from collections import defaultdict, OrderedDict | ||
from eole.config.models import BaseModelConfig | ||
from rich import print | ||
from eole.config import reorder_fields | ||
|
||
|
||
# to be extended | ||
KEY_MAP = { | ||
"save_model": ["training", "model_path"], | ||
"encoder_type": [ | ||
"model", | ||
"encoder", | ||
"encoder_type", | ||
], # will set config["model"]["encoder"]["encoder_type"] | ||
"decoder_type": ["model", "decoder", "decoder_type"], | ||
"enc_layers": ["model", "encoder", "layers"], | ||
"dec_layers": ["model", "decoder", "layers"], | ||
"heads": ["model", "heads"], | ||
"model_dtype": ["training", "compute_dtype"], | ||
"max_relative_positions": ["model", "embeddins", "n_positions"], | ||
"num_kv": ["model", "heads_kv"], | ||
"pos_ffn_activation_fn": ["model", "mlp_activation_fn"], | ||
} | ||
|
||
|
||
def custom_mapping(k, v, config): | ||
# to be extended | ||
match k, v: | ||
case "max_relative_positions", -1: | ||
config["model"]["embeddings"]["position_encoding_type"] = "Rotary" | ||
case "max_relative_positions", -2: | ||
config["model"]["embeddings"]["position_encoding_type"] = "Alibi" | ||
case "max_relative_positions", _ if v > 0: | ||
config["model"]["embeddings"]["positions_encoding_type"] = "Relative" | ||
case "max_relative_positions", 0: | ||
config["model"]["embeddings"]["position_encoding_type"] = "Absolute" | ||
|
||
|
||
def set_nested_value(d, keys, value): | ||
current = d | ||
for key in keys[:-1]: | ||
if key not in current: | ||
current[key] = {} | ||
current = current[key] | ||
current[keys[-1]] = value | ||
|
||
|
||
def _to_dict(d): | ||
if isinstance(d, defaultdict) or isinstance(d, OrderedDict): | ||
d = {k: _to_dict(v) for k, v in d.items()} | ||
return d | ||
|
||
|
||
@register_bin(name="onmt_config") | ||
class OnmtConfigConverter(BaseBin): | ||
@classmethod | ||
def add_args(cls, parser): | ||
parser.add_argument("input", help="Input OpenNMT-py yaml config to convert.") | ||
parser.add_argument("output", help="Output converted yaml EOLE config.") | ||
parser.add_argument( | ||
"-a", | ||
"--architecture", | ||
choices=["rnn", "cnn", "transformer", "transformer_lm"], | ||
help="Setting this field will ensure conversion will be relevant to the architecture.", | ||
) | ||
|
||
@classmethod | ||
def run(cls, args): | ||
with open(args.input) as f: | ||
data_in = yaml.safe_load(f) | ||
# 3 potential levels of nesting | ||
new_config = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) | ||
if "data" in data_in.keys(): | ||
new_config["data"] = data_in["data"] | ||
if args.architecture is not None: | ||
new_config["model"]["architecture"] = args.architecture | ||
mapped = {"data"} | ||
# retrieve all transforms to populate only necessary transforms_configs | ||
all_transforms = set(data_in.get("transforms", [])) | ||
for _, corpus in data_in.get("data", {}).items(): | ||
all_transforms |= set(corpus.get("transforms", [])) | ||
for k, v in data_in.items(): | ||
if k in KEY_MAP.keys(): | ||
set_nested_value(new_config, KEY_MAP[k], v) | ||
mapped.add(k) | ||
if k not in mapped: | ||
if k in TrainConfig.model_fields: | ||
new_config[k] = v | ||
mapped.add(k) | ||
elif k in TrainingConfig.model_fields: | ||
new_config["training"][k] = v | ||
mapped.add(k) | ||
elif k in BaseModelConfig.model_fields: | ||
new_config["model"][k] = v | ||
mapped.add(k) | ||
else: | ||
for ( | ||
t_name, | ||
t_field, | ||
) in NestedAllTransformsConfig.model_fields.items(): | ||
if k in t_field.default.model_fields: | ||
if t_name in all_transforms: | ||
try: | ||
# handle edge cases like {src,tgt}_onmttok_kwargs | ||
# str to dict conversion | ||
v = eval(v) | ||
except Exception: | ||
pass | ||
new_config["transforms_configs"][t_name][k] = v | ||
mapped.add(k) | ||
unmapped = set(data_in.keys()) - mapped | ||
# reorder fields for readability | ||
new_config = reorder_fields(new_config) | ||
# convert to standard dict for proper yaml dump | ||
new_config = _to_dict(new_config) | ||
print(new_config) | ||
print("Saving converted config to:", args.output) | ||
with open(args.output, "w") as f: | ||
yaml.dump(new_config, f, default_flow_style=False, sort_keys=False) | ||
print("Remaining unmapped items:", unmapped) | ||
# test config validity | ||
new_config["data"] = {} | ||
try: | ||
TrainConfig(**new_config) | ||
except Exception as e: | ||
print(e) | ||
else: | ||
print("Converted config seems ok!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
"pyonmttok>=1.37,<2", | ||
"pyyaml", | ||
"rapidfuzz", | ||
"rich", | ||
"sacrebleu", | ||
"safetensors", | ||
"sentencepiece", | ||
|