Skip to content

Commit

Permalink
formatted code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardorei committed Jul 17, 2021
1 parent f7fffb8 commit 21a73e9
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*DS_Store
data/
lightning_logs/
wmt21/

.vscode
# Byte-compiled / optimized / DLL files
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Basque, B
COMET implements the [Pytorch-Lightning model interface](https://pytorch-lightning.readthedocs.io/en/1.3.8/common/lightning_module.html) which means that you'll need to initialize a trainer in order to run inference.

```python
import torch
from comet import download_model, load_from_checkpoint
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader
Expand Down
14 changes: 6 additions & 8 deletions comet/cli/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def score_command() -> None:
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--mc_dropout", type=Union[bool, int], default=False)
cfg = parser.parse_args()

if (cfg.references is None) and ("refless" not in cfg.model):
parser.error("{} requires -r/--references.".format(cfg.model))

Expand All @@ -82,7 +82,7 @@ def score_command() -> None:
with open(cfg.references()) as fp:
references = [l.strip() for l in fp.readlines()]
data = {"src": sources, "mt": translations, "ref": references}

data = [dict(zip(data, t)) for t in zip(*data.values())]
dataloader = DataLoader(
dataset=data,
Expand All @@ -91,7 +91,7 @@ def score_command() -> None:
num_workers=multiprocessing.cpu_count(),
)
trainer = Trainer(gpus=cfg.gpus, deterministic=True, logger=False)

if cfg.mc_dropout:
model.set_mc_dropout(cfg.mc_dropout)
predictions = trainer.predict(
Expand All @@ -106,19 +106,19 @@ def score_command() -> None:
print("Segment {}\tscore: {:.3f}\tvariance: {:.3f}".format(i, mean, std))
sample["COMET"] = mean
sample["variance"] = std

print("System score: {:.3f}".format(sum(mean_scores) / len(mean_scores)))
if isinstance(cfg.to_json, str):
with open(cfg.to_json, "w") as outfile:
json.dump(data, outfile, ensure_ascii=False, indent=4)
print("Predictions saved in: {}.".format(cfg.to_json))

else:
predictions = trainer.predict(
model, dataloaders=dataloader, return_predictions=True
)
predictions = torch.cat(predictions, dim=0).tolist()

for i, (score, sample) in enumerate(zip(predictions, data)):
print("Segment {}\tscore: {:.3f}".format(i, score))
sample["COMET"] = score
Expand All @@ -128,5 +128,3 @@ def score_command() -> None:
with open(cfg.to_json, "w") as outfile:
json.dump(data, outfile, ensure_ascii=False, indent=4)
print("Predictions saved in: {}.".format(cfg.to_json))


2 changes: 1 addition & 1 deletion comet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@


def load_from_checkpoint(checkpoint_path: str):
""" Loads models from a checkpoint path.
"""Loads models from a checkpoint path.
:param checkpoint_path: Path to a model checkpoint.
:return: Returns a COMET model.
Expand Down
4 changes: 2 additions & 2 deletions comet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
else:
logger.warning(f"Path {load_weights_from_checkpoint} does not exist!")

self.mc_dropout = False # Flag used to control usage of MC Dropout
self.mc_dropout = False # Flag used to control usage of MC Dropout

def set_mc_dropout(self, value: bool):
self.mc_dropout = value
Expand Down Expand Up @@ -319,7 +319,7 @@ def predict_step(
mcd_mean = mcd_outputs.mean(dim=0)
mcd_std = mcd_outputs.std(dim=0)
return mcd_mean, mcd_std

return self(**batch)["score"].view(-1)

def validation_epoch_end(self, *args, **kwargs) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models/test_ranking_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_training(self):
default_root_dir=DATA_PATH,
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0
progress_bar_refresh_rate=0,
)
model = RankingMetric(
encoder_model="BERT",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models/test_referenceless_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_training(self):
default_root_dir=DATA_PATH,
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0
progress_bar_refresh_rate=0,
)
model = ReferencelessRegression(
encoder_model="BERT",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models/test_regression_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_training(self):
default_root_dir=DATA_PATH,
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0
progress_bar_refresh_rate=0,
)
model = RegressionMetric(
encoder_model="BERT",
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ def tearDownClass(cls):
def test_download_from_s3(self):
data_path = download_model("wmt21-small-da-152012", saving_directory=DATA_PATH)
self.assertTrue(
os.path.exists(os.path.join(DATA_PATH, "wmt21-small-da-152012/hparams.yaml"))
os.path.exists(
os.path.join(DATA_PATH, "wmt21-small-da-152012/hparams.yaml")
)
)
self.assertTrue(
os.path.exists(os.path.join(DATA_PATH, "wmt21-small-da-152012/checkpoints/"))
os.path.exists(
os.path.join(DATA_PATH, "wmt21-small-da-152012/checkpoints/")
)
)
model = load_from_checkpoint(data_path)

0 comments on commit 21a73e9

Please sign in to comment.