Skip to content

Commit

Permalink
Fix attribute errors: #185 #183
Browse files Browse the repository at this point in the history
  • Loading branch information
RicardoRei committed Jan 8, 2024
1 parent 45cb572 commit 97c5b9a
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion comet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def predict(

scores = torch.cat([pred["scores"] for pred in predictions], dim=0).tolist()
if "metadata" in predictions[0]:
metadata = flatten_metadata([pred.metadata for pred in predictions])
metadata = flatten_metadata([pred["metadata"] for pred in predictions])
else:
metadata = []

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/models/test_ranking_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_training(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_pos = torch.cat([p.scores for p in predictions], dim=0)
y_pos = torch.cat([p["scores"] for p in predictions], dim=0)

# Scores for "worse" translations
neg_translations = [
Expand All @@ -96,7 +96,7 @@ def test_training(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_neg = torch.cat([p.scores for p in predictions], dim=0)
y_neg = torch.cat([p["scores"] for p in predictions], dim=0)
## This shouldn't break!
pearsonr(y_pos, y_neg)[0]

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models/test_referenceless_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,5 @@ def test_training(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_hat = torch.cat([p.scores for p in predictions], dim=0).tolist()
y_hat = torch.cat([p["scores"] for p in predictions], dim=0).tolist()
assert pearsonr(y_hat, y)[0] > 0.85
3 changes: 2 additions & 1 deletion tests/integration/models/test_regression_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,6 @@ def test_training(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_hat = torch.cat([p.scores for p in predictions], dim=0).tolist()
breakpoint()
y_hat = torch.cat([p["scores"] for p in predictions], dim=0).tolist()
assert pearsonr(y_hat, y)[0] > 0.85
4 changes: 2 additions & 2 deletions tests/integration/models/test_unified_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_regression_with_references(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_hat = torch.cat([p.scores for p in predictions], dim=0).tolist()
y_hat = torch.cat([p["scores"] for p in predictions], dim=0).tolist()
assert pearsonr(y_hat, y)[0] > 0.9

def test_regression_without_references(self):
Expand Down Expand Up @@ -153,6 +153,6 @@ def test_regression_without_references(self):
predictions = trainer.predict(
ckpt_path="best", dataloaders=dataloader, return_predictions=True
)
y_hat = torch.cat([p.scores for p in predictions], dim=0).tolist()
y_hat = torch.cat([p["scores"] for p in predictions], dim=0).tolist()
assert pearsonr(y_hat, y)[0] > 0.9

0 comments on commit 97c5b9a

Please sign in to comment.