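"""1_custom_model.py: AutoNMT example that builds several translation datasets,
then trains and evaluates a custom Transformer model on each of them."""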
import datetime
from autonmt.bundle.report import generate_report
from autonmt.modules.models import Transformer
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import AutonmtTranslator
from autonmt.vocabularies import Vocabulary
from autonmt.preprocessing.processors import preprocess_pairs, preprocess_lines, normalize_lines
from tokenizers.normalizers import NFKC, Strip

# Preprocess functions
normalize_fn = lambda x: normalize_lines(x, seq=[NFKC(), Strip()])
preprocess_raw_fn = lambda x, y: preprocess_pairs(x, y, normalize_fn=normalize_fn, min_len=1, max_len=None, remove_duplicates=False, shuffle_lines=True)
preprocess_splits_fn = lambda x, y: preprocess_pairs(x, y, normalize_fn=normalize_fn)
preprocess_predict_fn = lambda x: preprocess_lines(x, normalize_fn=normalize_fn)
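
# Note on the three helpers above (inferred from their arguments, not from
# autonmt internals): preprocess_raw_fn cleans the raw parallel corpus (NFKC
# normalization + strip, drops empty lines via min_len=1, shuffles the line
# pairs), preprocess_splits_fn only normalizes the train/val/test splits, and
# preprocess_predict_fn normalizes single source-side lines at inference time.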


def main():
    # Create the preprocessing pipeline for training
    builder = DatasetBuilder(
        # Root folder for datasets
        base_path="datasets/translate",

        # Set of datasets, languages and training sizes to try
        datasets=[
            {"name": "europarl", "languages": ["es-en", "fr-en", "de-en"], "sizes": [("original", None), ("100k", 100000)]},
            {"name": "scielo/health", "languages": ["es-en"], "sizes": [("100k", 100000)], "split_sizes": (None, 1000, 1000)},
        ],

        # Set of subword models and vocab sizes to try
        encoding=[
            {"subword_models": ["bpe", "unigram+bytes"], "vocab_sizes": [8000, 16000, 32000]},
            {"subword_models": ["bytes", "char", "char+bytes"], "vocab_sizes": [1000]},
        ],

        # Preprocessing functions
        preprocess_raw_fn=preprocess_raw_fn,
        preprocess_splits_fn=preprocess_splits_fn,

        # Additional args
        merge_vocabs=False,
    ).build(make_plots=False, force_overwrite=False)
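
    # Note: the builder materializes one preprocessed dataset per combination of
    # dataset x size x subword model x vocab size, so the loop below trains and
    # scores one model per combination.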

    # Get the train/test datasets produced by the builder
    tr_datasets = builder.get_train_ds()
    ts_datasets = builder.get_test_ds()

    # Train & score a model for each training dataset
    scores = []
    for train_ds in tr_datasets:
        # Instantiate vocabs and model
        src_vocab = Vocabulary(max_tokens=150).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
        trg_vocab = Vocabulary(max_tokens=150).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
        model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
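        # Both vocabularies are built the same way, so the source pad id is
        # assumed to be valid for the target side as well.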

        # Define trainer
        runs_dir = train_ds.get_runs_path(toolkit="autonmt")
        run_name = train_ds.get_run_name(run_prefix="mymodel")
        trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab,
                                    runs_dir=runs_dir, run_name=run_name)

        # Train model (set wandb_params to enable Weights & Biases logging)
        wandb_params = None  # e.g., dict(project="autonmt", entity="salvacarrion")
        trainer.fit(train_ds, max_epochs=5, learning_rate=0.001, optimizer="adam", batch_size=128, seed=1234,
                    patience=10, num_workers=10, strategy="ddp", save_best=True, save_last=True,
                    wandb_params=wandb_params)
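
        # Note: strategy="ddp" requests PyTorch Lightning's distributed data
        # parallel and presumably targets multi-GPU machines; for a single-GPU
        # or CPU run it can likely be dropped.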

        # Evaluate on the test sets (beam sizes 1 and 5, best checkpoint)
        m_scores = trainer.predict(ts_datasets, metrics={"bleu", "chrf", "bertscore"}, beams=[1, 5],
                                   load_checkpoint="best", preprocess_fn=preprocess_predict_fn,
                                   eval_mode="compatible", force_overwrite=False)
        scores.append(m_scores)

    # Make a report and print it (the timestamp is formatted so the output path
    # contains no spaces or colons)
    output_path = f".outputs/autonmt/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    df_report, df_summary = generate_report(scores=scores, output_path=output_path,
                                            plot_metric="translations.beam1.sacrebleu_bleu_score")
    print("Summary:")
    print(df_summary.to_string(index=False))


if __name__ == "__main__":
    main()
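
# Usage sketch (assumes the corpora above are laid out under datasets/translate
# in the structure DatasetBuilder expects):
#   python 1_custom_model.py
# Checkpoints and logs go under each dataset's runs_dir; the aggregated report
# is written to .outputs/autonmt/<timestamp>.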