Commit e50609e (1 parent: 4a9b605)
Showing 10 changed files with 290 additions and 0 deletions.
@@ -0,0 +1,11 @@
FROM pure/python:3.7-cuda10.0-base

WORKDIR /app

# Install the library
RUN pip install --upgrade pip && \
    pip install --no-cache-dir unbabel-comet==1.0.1 --use-feature=2020-resolver

# Run a warmup query
COPY scripts/warmup.sh warmup.sh
RUN sh warmup.sh
@@ -0,0 +1,59 @@
# Rei et al. (2020)

## Publication
[COMET: A Neural Framework for MT Evaluation](https://aclanthology.org/2020.emnlp-main.213/)

## Repositories
https://github.com/Unbabel/COMET

## Available Models
COMET is available with either the reference-based `wmt20-comet-da` model or the reference-free `wmt20-comet-qe-da` model. The reference-based model is used when the inputs include references; otherwise the reference-free model is used.

- COMET:
  - Description: A machine translation evaluation metric.
  - Name: `rei2020-comet`
  - Usage:
    ```python
    from repro.models.rei2020 import COMET
    model = COMET()
    # reference-based
    inputs = [
        {"candidate": "The candidate to score", "sources": ["The source text"], "references": ["The reference"]}
    ]
    macro, micro = model.predict_batch(inputs)

    # reference-free
    inputs = [
        {"candidate": "The candidate to score", "sources": ["The source text"]}
    ]
    macro, micro = model.predict_batch(inputs)
    ```
    `macro` holds the COMET score averaged over all inputs, and `micro` holds the per-input scores.
    The reference-based key is `"comet"` and the reference-free key is `"comet-src"`.
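
To make the relationship between the two return values concrete, the sketch below recomputes `macro` from `micro` by averaging each key. The `average_dicts` function here is a hypothetical stand-in written for illustration; it is not claimed to match `repro.common.util.average_dicts`. The scores are the reference-based regression values from this commit's unit test.

```python
from typing import Dict, List


def average_dicts(dicts: List[Dict[str, float]]) -> Dict[str, float]:
    """Hypothetical stand-in: average each key over a list of flat metric dicts."""
    if not dicts:
        return {}
    keys = dicts[0].keys()
    return {key: sum(d[key] for d in dicts) / len(dicts) for key in keys}


# Per-input ("micro") scores taken from the regression test in this commit
micro = [{"comet": 0.19016893208026886}, {"comet": 0.9156624674797058}]

# The averaged ("macro") score is the mean of the per-input scores
macro = average_dicts(micro)
print(macro)
```

The printed value should agree with the test's `expected_macro` up to floating-point rounding.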

## Implementation Notes
Only one source document and one reference translation are supported per input.
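
A minimal sketch of what such a restriction could look like in code follows. `check_single_text` is a hypothetical helper for illustration only; the actual check in this commit is delegated to `repro.common.util.check_for_single_texts`, whose implementation is not shown here and may behave differently.

```python
from typing import List, Optional


def check_single_text(texts: Optional[List[str]], name: str) -> Optional[str]:
    """Hypothetical check: return the only text in `texts`, or None if absent."""
    if texts is None:
        return None
    if len(texts) != 1:
        raise ValueError(f"Expected exactly 1 {name}, found {len(texts)}")
    return texts[0]


# One source is accepted; a missing reference is passed through as None
source = check_single_text(["The source text"], "source")
reference = check_single_text(None, "reference")
print(source, reference)
```

Passing a list with two or more entries raises a `ValueError` in this sketch, which mirrors the stated limitation.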

## Docker Information
- Image name: `danieldeutsch/rei2020:1.0`
- Build command:
  ```shell script
  repro setup rei2020
  ```
- Requires network: Yes, the code still makes a network request even if the models are pre-cached.

## Testing
```shell script
repro setup rei2020
pytest models/rei2020/tests
```

## Status
- [x] Regression unit tests pass
  See [here](https://github.com/danieldeutsch/repro/actions/runs/1567865901)
- [ ] Correctness unit tests pass
- [ ] Model runs on full test dataset
- [ ] Predictions approximately replicate results reported in the paper
- [ ] Predictions exactly replicate results reported in the paper

## Changelog
@@ -0,0 +1,12 @@
# Examples taken from https://github.com/Unbabel/COMET
# printf is used because `echo -e` is not portable across `sh` implementations
printf '%s\n' "Dem Feuer konnte Einhalt geboten werden" "Schulen und Kindergärten wurden eröffnet." >> src.de
printf '%s\n' "The fire could be stopped" "Schools and kindergartens were open" >> hyp1.en
printf '%s\n' "They were able to control the fire." "Schools and kindergartens opened" >> ref.en

# Run a reference-based version
comet-score -s src.de -t hyp1.en -r ref.en --gpus 0

# Run a reference-free version
comet-score -s src.de -t hyp1.en --gpus 0 --model wmt20-comet-qe-da

rm src.de hyp1.en ref.en
Empty file.
@@ -0,0 +1,62 @@
import unittest
from parameterized import parameterized

from repro.models.rei2020 import COMET
from repro.testing import assert_dicts_approx_equal, get_testing_device_parameters


class TestRei2020Models(unittest.TestCase):
    @parameterized.expand(get_testing_device_parameters())
    def test_comet_regression(self, device: int):
        # Tests the examples from the GitHub repo
        model = COMET(device=device)

        inputs = [
            {
                "sources": ["Dem Feuer konnte Einhalt geboten werden"],
                "candidate": "The fire could be stopped",
                "references": ["They were able to control the fire."],
            },
            {
                "sources": ["Schulen und Kindergärten wurden eröffnet."],
                "candidate": "Schools and kindergartens were open",
                "references": ["Schools and kindergartens opened"],
            },
        ]

        expected_macro = {"comet": 0.5529156997799873}
        expected_micro = [{"comet": 0.19016893208026886}, {"comet": 0.9156624674797058}]
        actual_macro, actual_micro = model.predict_batch(inputs)

        assert_dicts_approx_equal(expected_macro, actual_macro)
        assert len(expected_micro) == len(actual_micro)
        for expected, actual in zip(expected_micro, actual_micro):
            assert_dicts_approx_equal(expected, actual, abs=1e-4)

    @parameterized.expand(get_testing_device_parameters())
    def test_comet_src_regression(self, device: int):
        # Tests the examples from the GitHub repo
        model = COMET(device=device)

        inputs = [
            {
                "sources": ["Dem Feuer konnte Einhalt geboten werden"],
                "candidate": "The fire could be stopped",
            },
            {
                "sources": ["Schulen und Kindergärten wurden eröffnet."],
                "candidate": "Schools and kindergartens were open",
            },
        ]

        expected_macro = {"comet-src": 0.35479202680289745}
        expected_micro = [
            {"comet-src": 0.00831037387251854},
            {"comet-src": 0.7012736797332764},
        ]
        actual_macro, actual_micro = model.predict_batch(inputs)

        assert_dicts_approx_equal(expected_macro, actual_macro)
        assert len(expected_micro) == len(actual_micro)
        for expected, actual in zip(expected_micro, actual_micro):
            assert_dicts_approx_equal(expected, actual, abs=1e-4)
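
The regression tests compare floating-point scores within a tolerance rather than exactly. The sketch below shows one way such a comparison could work using `math.isclose`. `dicts_approx_equal` is a hypothetical helper written for illustration; the real `repro.testing.assert_dicts_approx_equal` is not shown in this commit and may be implemented differently.

```python
import math
from typing import Dict


def dicts_approx_equal(
    expected: Dict[str, float], actual: Dict[str, float], abs_tol: float = 1e-4
) -> bool:
    """Hypothetical helper: same keys, and each value within `abs_tol`."""
    if expected.keys() != actual.keys():
        return False
    return all(
        math.isclose(expected[key], actual[key], rel_tol=0.0, abs_tol=abs_tol)
        for key in expected
    )


expected = {"comet": 0.19016893208026886}
within = dicts_approx_equal(expected, {"comet": 0.19020})   # differs by about 3e-5
outside = dicts_approx_equal(expected, {"comet": 0.19100})  # differs by about 8e-4
print(within, outside)
```

Setting `rel_tol=0.0` makes the check purely absolute, which matches the intent of the `abs=1e-4` argument used in the tests above.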
@@ -0,0 +1,10 @@
import os

VERSION = "1.0"
MODEL_NAME = os.path.basename(os.path.dirname(__file__))
DOCKERHUB_REPO = f"danieldeutsch/{MODEL_NAME}"
DEFAULT_IMAGE = f"{DOCKERHUB_REPO}:{VERSION}"
AUTOMATICALLY_PUBLISH = True

from repro.models.rei2020.model import COMET
from repro.models.rei2020.setup import Rei2020SetupSubcommand
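
The model name is derived from the name of the directory that contains this file, so the image tag follows from the package layout. The sketch below traces that derivation. The path in `init_file` is a hypothetical placeholder standing in for `__file__`; only the final directory name matters.

```python
import os

# Hypothetical location of the package's __init__.py, for illustration only
init_file = "/path/to/repro/models/rei2020/__init__.py"

version = "1.0"
# dirname drops "__init__.py"; basename keeps the last directory component
model_name = os.path.basename(os.path.dirname(init_file))
dockerhub_repo = f"danieldeutsch/{model_name}"
default_image = f"{dockerhub_repo}:{version}"
print(default_image)  # danieldeutsch/rei2020:1.0
```

The result matches the image name listed in the README's Docker Information section.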
@@ -0,0 +1,124 @@
import json
import logging
from typing import Dict, List, Tuple, Union

from repro.common import util
from repro.common.docker import DockerContainer
from repro.common.io import write_to_text_file
from repro.data.types import MetricsType, TextType
from repro.models import Model
from repro.models.rei2020 import DEFAULT_IMAGE, MODEL_NAME

logger = logging.getLogger(__name__)


@Model.register(f"{MODEL_NAME}-comet")
class COMET(Model):
    def __init__(self, image: str = DEFAULT_IMAGE, device: int = 0):
        self.image = image
        self.device = device

    def predict(
        self,
        candidate: TextType,
        sources: List[TextType] = None,
        references: List[TextType] = None,
        batch_size: int = None,
    ) -> MetricsType:
        return self.predict_batch(
            [{"candidate": candidate, "sources": sources, "references": references}],
            batch_size=batch_size,
        )[0]

    def predict_batch(
        self,
        inputs: List[Dict[str, Union[TextType, List[TextType]]]],
        batch_size: int = None,
    ) -> Tuple[MetricsType, List[MetricsType]]:
        logger.info(f"Calculating COMET for {len(inputs)} inputs")

        batch_size = batch_size or 8

        candidates = [inp["candidate"] for inp in inputs]
        sources_list = [inp["sources"] if "sources" in inp else None for inp in inputs]
        references_list = [
            inp["references"] if "references" in inp else None for inp in inputs
        ]

        # If any input has a reference, they all must
        def _has_references(references: List[TextType]) -> bool:
            return references is not None and len(references) > 0

        has_references = any(
            _has_references(references) for references in references_list
        )
        if has_references:
            if not all(_has_references(references) for references in references_list):
                raise Exception(
                    "COMET requires that all or none of the inputs have references"
                )

        # COMET only supports single sources and references
        sources = util.check_for_single_texts(sources_list)
        if has_references:
            references = util.check_for_single_texts(references_list)

        # Ensure all are strings or None
        candidates = [util.flatten(candidate) for candidate in candidates]
        sources = [util.flatten(source) for source in sources]
        if has_references:
            references = [util.flatten(reference) for reference in references]

        with DockerContainer(self.image) as backend:
            host_src_file = f"{backend.host_dir}/src.txt"
            container_src_file = f"{backend.container_dir}/src.txt"
            write_to_text_file(sources, host_src_file)

            hyp_filename = "hyp1.txt"
            host_hyp_file = f"{backend.host_dir}/{hyp_filename}"
            container_hyp_file = f"{backend.container_dir}/{hyp_filename}"
            write_to_text_file(candidates, host_hyp_file)

            host_ref_file = f"{backend.host_dir}/ref.txt"
            container_ref_file = f"{backend.container_dir}/ref.txt"
            if has_references:
                write_to_text_file(references, host_ref_file)

            host_output_file = f"{backend.host_dir}/output.json"
            container_output_file = f"{backend.container_dir}/output.json"

            # A device of -1 means run on CPU
            cuda = self.device != -1
            commands = []
            if cuda:
                commands.append(f"export CUDA_VISIBLE_DEVICES={self.device}")
                num_gpus = 1
            else:
                num_gpus = 0

            score_command = (
                f"comet-score "
                f"-s {container_src_file} "
                f"-t {container_hyp_file} "
                f"--gpus {num_gpus} "
                f"--batch_size {batch_size} "
                f"--to_json {container_output_file}"
            )
            if has_references:
                score_command += f" -r {container_ref_file} --model wmt20-comet-da"
            else:
                score_command += " --model wmt20-comet-qe-da"
            commands.append(score_command)

            command = " && ".join(commands)
            backend.run_command(command=command, cuda=cuda, network_disabled=False)

            # Close the output file after reading instead of leaking the handle
            with open(host_output_file, "r") as f:
                output_dict = json.load(f)
            outputs = output_dict[container_hyp_file]

        metric = "comet" if has_references else "comet-src"

        micro = []
        for output in outputs:
            micro.append({metric: output["COMET"]})
        macro = util.average_dicts(micro)
        return macro, micro
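
The command assembly inside `predict_batch` can be read in isolation as a pure function, which makes the CPU/GPU and reference-based/reference-free branches easier to follow. The sketch below restates that logic under the assumption that only the flags visible in this commit are needed. `build_score_command` is a hypothetical name and is not part of the commit; the file names are placeholders.

```python
from typing import List, Optional


def build_score_command(
    src_file: str,
    hyp_file: str,
    output_file: str,
    ref_file: Optional[str] = None,
    device: int = 0,
    batch_size: int = 8,
) -> str:
    """Hypothetical restatement of the shell command built in `predict_batch`."""
    commands: List[str] = []
    cuda = device != -1  # -1 selects CPU, as in the model code
    if cuda:
        # Restrict the process to the requested GPU
        commands.append(f"export CUDA_VISIBLE_DEVICES={device}")
    num_gpus = 1 if cuda else 0

    score_command = (
        f"comet-score -s {src_file} -t {hyp_file} "
        f"--gpus {num_gpus} --batch_size {batch_size} --to_json {output_file}"
    )
    if ref_file is not None:
        # References present: use the reference-based model
        score_command += f" -r {ref_file} --model wmt20-comet-da"
    else:
        # No references: use the reference-free (quality estimation) model
        score_command += " --model wmt20-comet-qe-da"
    commands.append(score_command)
    return " && ".join(commands)


cpu_qe = build_score_command("src.txt", "hyp1.txt", "output.json", device=-1)
gpu_da = build_score_command("src.txt", "hyp1.txt", "output.json", ref_file="ref.txt")
print(cpu_qe)
print(gpu_da)
```

Separating command construction from the Docker call in this way would also allow the branching to be unit tested without building the image.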
@@ -0,0 +1,10 @@
from repro import MODELS_ROOT
from repro.commands.subcommand import SetupSubcommand
from repro.common.docker import BuildDockerImageSubcommand
from repro.models.rei2020 import DEFAULT_IMAGE, MODEL_NAME


@SetupSubcommand.register(MODEL_NAME)
class Rei2020SetupSubcommand(BuildDockerImageSubcommand):
    def __init__(self) -> None:
        super().__init__(f"{MODELS_ROOT}/{MODEL_NAME}", DEFAULT_IMAGE)
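
The `@SetupSubcommand.register(MODEL_NAME)` and `@Model.register(...)` decorators suggest a name-to-class registry, which is presumably how `repro setup rei2020` finds this subcommand. The sketch below illustrates that general pattern only. `Registrable`, `by_name`, and `FakeComet` are hypothetical names, and nothing here is claimed to reflect how `repro` implements its registry.

```python
from typing import Callable, Dict, Type


class Registrable:
    """Hypothetical base class holding a name-to-class registry."""

    _registry: Dict[str, Type["Registrable"]] = {}

    @classmethod
    def register(
        cls, name: str
    ) -> Callable[[Type["Registrable"]], Type["Registrable"]]:
        def decorator(subclass: Type["Registrable"]) -> Type["Registrable"]:
            if name in cls._registry:
                raise ValueError(f"{name!r} is already registered")
            cls._registry[name] = subclass
            return subclass  # return the class unchanged so it stays usable

        return decorator

    @classmethod
    def by_name(cls, name: str) -> Type["Registrable"]:
        return cls._registry[name]


@Registrable.register("rei2020-comet")
class FakeComet(Registrable):
    pass


looked_up = Registrable.by_name("rei2020-comet")
print(looked_up.__name__)
```

A registry like this lets a command-line entry point map a string argument to a class without importing each model explicitly at the call site.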