Multiple beams translate & evaluation with bleu #6

Open: wants to merge 9 commits into base `main`
57 changes: 57 additions & 0 deletions calculate_metrics.py
```python
import json
import os
from typing import Any, Dict, List

import evaluate


def read_all_lines(file_name: str) -> List[Dict[str, Any]]:
    """Read a JSONL file and keep only the top-ranked (rank == 0) beam hypotheses."""
    all_lines = []
    with open(file_name, "r", encoding="utf-8") as file_obj:
        for line in file_obj:
            json_obj = json.loads(line)
            if json_obj["rank"] == 0:
                all_lines.append(json_obj)

    return all_lines


def main(dir_name: str):
    all_files = os.listdir(dir_name)
    sacrebleu = evaluate.load("sacrebleu")
    chrf = evaluate.load("chrf")
    for filename in sorted(all_files):
        if filename.endswith("jsonl"):
            filepath = f"{dir_name}/{filename}"
            print(filepath)
            all_lines_rank_0 = read_all_lines(filepath)
            # sacrebleu expects a list of references per prediction.
            refs = [[row["ref"]] for row in all_lines_rank_0]
            hyps = [row["hyp"] for row in all_lines_rank_0]
            score_bleu = sacrebleu.compute(predictions=hyps, references=refs)
            # spBLEU: BLEU computed over the FLORES SentencePiece tokenizations.
            score_spbleu_101 = sacrebleu.compute(
                predictions=hyps, references=refs, tokenize="flores101"
            )
            score_spbleu_200 = sacrebleu.compute(
                predictions=hyps, references=refs, tokenize="flores200"
            )
            score_chrf = chrf.compute(predictions=hyps, references=refs)
            # chrF++ adds word n-grams (order 2) on top of character n-grams.
            score_chrf_word_order_2 = chrf.compute(
                predictions=hyps, references=refs, word_order=2
            )
            metrics = [
                {"metric": "bleu"} | score_bleu,
                {"metric": "spbleu-101"} | score_spbleu_101,
                {"metric": "spbleu-200"} | score_spbleu_200,
                {"metric": "chrf"} | score_chrf,
                {"metric": "chrf++"} | score_chrf_word_order_2,
            ]
            with open(
                f"results_for_sales_representative_paniv/{filename.replace('jsonl', 'metrics')}",
                "w",
            ) as file_obj:
                for metric in metrics:
                    file_obj.write(json.dumps(metric) + "\n")


if __name__ == "__main__":
    main("eval-beams-paniv")
```
45 changes: 45 additions & 0 deletions docs/translate_beams.md
# Usage Guide for translate_beams.py

## Links
1. Models list - https://github.com/Helsinki-NLP/UkrainianLT/blob/main/opus-mt-ukr-flores-devtest-big.md
2. Model download link - https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zle/opusTCv20210807+bt_transformer-big_2022-03-13.zip

## Running the code
1. Download the model from link 2
2. Convert the model to a CTranslate2-compatible format:
```sh
ct2-opus-mt-converter --model_dir opusTCv20210807+bt_transformer-big_2022-03-13 --output_dir opusTCv20210807+bt_transformer-big_2022-03-13_ct2_model
```
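The same conversion can also be scripted from Python; a minimal sketch, assuming the `OpusMTConverter` class exposed by your installed CTranslate2 version:
```python
from ctranslate2.converters import OpusMTConverter

# Convert the downloaded OPUS-MT checkpoint into a CTranslate2 model directory.
converter = OpusMTConverter("opusTCv20210807+bt_transformer-big_2022-03-13")
converter.convert("opusTCv20210807+bt_transformer-big_2022-03-13_ct2_model")
```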
3. Load the data from FLORES and store the English source sentences:
```python
from datasets import load_dataset
import csv


dataset = load_dataset("facebook/flores", "eng_Latn-ukr_Cyrl")
dev = dataset["dev"]
devtest = dataset["devtest"]
dev.to_csv("flores-dev.csv")
devtest.to_csv("flores-devtest.csv")

eng = devtest["sentence_eng_Latn"]

def write_to_csv(list_of_emails):
    with open('flores-eng-devtest.csv', 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["eng_Latn-ukr_Cyrl"])
        writer.writeheader()
        for domain in list_of_emails:
            csvfile.write(domain + '\n')

write_to_csv(eng)
```

> **Collaborator** (on the `fieldnames=["eng_Latn-ukr_Cyrl"]` line): fieldnames looks wrong.

> **Collaborator** (on the `for domain in list_of_emails:` line): list_of_emails?
4. Preprocess the English source. Remove the last two lines of the `else` branch in `preprocess.sh` so that tokenization is skipped here (the translation script performs tokenization itself):
```sh
./preprocess.sh eng ukr source.spm < flores-eng-devtest.csv > preprocessed_devtest.csv
```
5. Run the translation:
```sh
python3 translate_beams.py \
    --source-file-path=flores-devtest.csv \
    --preprocessed-file-path=preprocessed_devtest.csv \
    --target-file-path=target-opus.csv \
    --translation-model-path=opus_ct2_model/ \
    --tokenizer-model-path=./opus_ct2_model/source.spm \
    --target-tokenizer-model-path=./opus_ct2_model/target.spm \
    --validation-field-name=sentence_ukr_Cyrl \
    --source-field-name=sentence_eng_Latn \
    --src-prefix=">>ukr<<" \
    --target-prefix=">>ukr<<" \
    --beam-size=2
```
P.S. Postprocessing is already implemented in the script.
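For orientation, the core beam-search call that a script like `translate_beams.py` wraps looks roughly like this; a minimal sketch, not the PR's implementation, with the model paths and the `>>ukr<<` prefix taken from the command above:

```python
import ctranslate2
import sentencepiece as spm

translator = ctranslate2.Translator("opus_ct2_model/")
sp_src = spm.SentencePieceProcessor(model_file="opus_ct2_model/source.spm")
sp_tgt = spm.SentencePieceProcessor(model_file="opus_ct2_model/target.spm")

sentence = "How are you?"
# Multi-target OPUS-MT models expect a target-language prefix token.
tokens = [">>ukr<<"] + sp_src.encode(sentence, out_type=str)
results = translator.translate_batch(
    [tokens], beam_size=2, num_hypotheses=2, return_scores=True
)
# One TranslationResult per input; hypotheses are ordered best-first (rank 0 first).
for rank, hyp in enumerate(results[0].hypotheses):
    print(rank, sp_tgt.decode_pieces(hyp))
```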
128 changes: 128 additions & 0 deletions few_shot_gpt_translate.py
```python
import csv
import json
import time
from typing import Any, Dict, List

import evaluate
from openai import OpenAI
from tqdm import tqdm

SYSTEM_PROMPT_TEMPLATE = """
You are a professional English to Ukrainian translator. Complete the translation according to the domain examples.
###

{translation_few_shot}

###

English: {original}
Translation:
"""


def write_to_file(
    target_file_path: str,
    source_sentences: List[str],
    translation_sentences: List[str],
    validation_sentences: List[str],
    bleu_scores: List[float],
) -> None:
    """
    Write the evaluation results to a CSV file.

    Args:
        target_file_path: path to the target file
        source_sentences: list of source sentences
        translation_sentences: list of translated sentences
        validation_sentences: list of reference (validation) sentences
        bleu_scores: list of per-sentence BLEU scores

    Returns:
        None
    """
    evaluation_entity: Dict[str, List[Any]] = {
        "source": source_sentences,
        "original_translation": validation_sentences,
        "mt_translation": translation_sentences,
        "bleu": bleu_scores,
    }

    # Transpose the column-oriented dict into one row dict per sentence.
    evaluation_entity_list: List[Dict[str, Any]] = [
        dict(zip(evaluation_entity, t)) for t in zip(*evaluation_entity.values())
    ]

    with open(target_file_path, "w", encoding="utf-8") as fp_out:
        writer = csv.DictWriter(fp_out, fieldnames=evaluation_entity_list[0].keys())
        writer.writeheader()
        writer.writerows(evaluation_entity_list)


def read_all_lines(file_name: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of dicts, one per line."""
    all_lines = []
    with open(file_name, "r", encoding="utf-8") as file_obj:
        for line in file_obj:
            json_obj = json.loads(line)
            all_lines.append(json_obj)

    return all_lines


def prepare_prompt(query_result: Dict[str, Any]) -> str:
    """Build a few-shot prompt from the retrieved context translation pairs."""
    original = query_result["orig"]
    context_pairs: List[Dict[str, str]] = query_result["context"]
    translation_few_shot = ""
    for translation_pair in context_pairs:
        translation_few_shot += f"English: {translation_pair['orig']}\n"
        translation_few_shot += f"Translation: {translation_pair['trans']}\n\n"

    return SYSTEM_PROMPT_TEMPLATE.format(
        translation_few_shot=translation_few_shot, original=original
    )


def main():
    all_scores = []
    start_time = time.time()
    sacrebleu = evaluate.load("sacrebleu")
    client = OpenAI(
        api_key="*****",  # redacted
    )
    all_query_results = read_all_lines(
        "data/flores_context/context_floresdev_sbert_loose.jsonl"
    )
    for query_result in tqdm(all_query_results):
        translation_prompt = prepare_prompt(query_result)
        completion = client.chat.completions.create(
            model="gpt-4",
            # model="gpt-4-turbo-preview",
            # model="gpt-3.5-turbo",
            messages=[{"role": "system", "content": translation_prompt}],
        )
        # Keep per-sentence BLEU alongside src/ref/hyp for the CSV report.
        score = {}
        score["src"] = query_result["orig"]
        score["ref"] = query_result["trans"]
        score["hyp"] = completion.choices[0].message.content
        score["sacrebleu"] = sacrebleu.compute(
            predictions=[score["hyp"]], references=[score["ref"]]
        )["score"]
        all_scores.append(score)

    references = [score["ref"] for score in all_scores]
    translations = [score["hyp"] for score in all_scores]
    source_sentences = [score["src"] for score in all_scores]
    sacrebleu_scores = [score["sacrebleu"] for score in all_scores]
    write_to_file(
        target_file_path="results/context_floresdev_sbert_loose_scores.csv",
        source_sentences=source_sentences,
        translation_sentences=translations,
        validation_sentences=references,
        bleu_scores=sacrebleu_scores,
    )
    # Corpus-level BLEU over all hypotheses (not the average of per-sentence scores).
    evaluation_result_sacrebleu = sacrebleu.compute(
        predictions=translations, references=references
    )
    print(evaluation_result_sacrebleu)
    print(f"Execution lasted for {time.time() - start_time:.1f} seconds")


if __name__ == "__main__":
    main()
```
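For reference, each line of the context JSONL consumed above carries the query sentence plus its retrieved few-shot pairs. The field names below are taken directly from `prepare_prompt`; any additional fields in the real file are an assumption left out here:

```python
# Hypothetical example of one line of context_floresdev_sbert_loose.jsonl,
# shown as the parsed Python dict. Only "orig", "trans", and "context"
# are actually read by the script above.
example_line = {
    "orig": "The weather is nice today.",
    "trans": "Сьогодні гарна погода.",
    "context": [
        {"orig": "It is raining.", "trans": "Іде дощ."},
        {"orig": "Winter is coming.", "trans": "Зима наближається."},
    ],
}
```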