Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

updating requirements #4860

Merged
merged 2 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions projects/roscoe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,16 @@ bash projects/roscoe/roscoe_data/download_annotated.sh
**Human annotated data: Pending data release approval**

One-time setup:
It is higly recommended to run baseline scoring from a separate conda environment.
```bash
conda create --name roscoe_baselines python=3.8
conda activate roscoe_baselines
```
Follow BLEURT [installation quidelines](https://github.com/google-research/bleurt#installation).
Clone [BartScore repo](https://github.com/neulab/BARTScore) and update path in projects/roscoe/baselines/score.py. Install Bart_score requirements.
Upload fine-tuned [BART model](https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth).
Download PRISM [installation quidelines](https://github.com/thompsonb/prism) and download the model. Do not install requirements.
Install requirements to run baselines:
```bash
python -c "import nltk; nltk.download('punkt')"
python -c "import nltk; nltk.download('stopwords')"
pip install -r projects/roscoe/baselines/requirements.txt
pip install -r projects/roscoe/baselines/bart_requirements.txt
```
Follow BLEURT [installation quidelines](https://github.com/google-research/bleurt#installation)
Upload fine-tuned [BART model](https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth)

Then you can run baselines on all datasets, scores, and use of reference (when possible) with the following:
```bash
Expand Down
98 changes: 0 additions & 98 deletions projects/roscoe/baselines/bart_requirements.txt

This file was deleted.

6 changes: 2 additions & 4 deletions projects/roscoe/baselines/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
rouge-score
bert-score
ctc_score
rouge-score>=0.1.2
bert-score>=0.3.9
sentencepiece>=0.1.86
fairseq==0.9.0
sacrebleu>=1.4.8
torch>=1.4.0
ctc_score
2 changes: 1 addition & 1 deletion projects/roscoe/baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def save_scores_map(path, metric_to_score):
type=str,
choices=[x.value for x in UseRef],
nargs="+",
default=[x.value for x in UseRef],
default=UseRef.NO.value,
help='do we want to generate reference-based or reference-free scores',
)
parser.add_argument(
Expand Down
33 changes: 19 additions & 14 deletions projects/roscoe/baselines/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from importlib.machinery import SourceFileLoader

BART_SCORE_REPO = "/path_to/BARTScore/"
PRISM_SCORE_REPO = "/path_to/SUM"
BART_SCORE_REPO = "/path_to/BARTScore"
PRISM_SCORE_REPO = "/path_to/prism"
BLEURT_SCORE_REPO = "/path_to/bleurt"

######### Base functionality
Expand Down Expand Up @@ -114,7 +114,9 @@ def get_scores(self, score_me):
@register_scorer([BLEURT])
class BleurtBaselineScorer(BaselineScorer):
def __init__(self):
self.scorer = bleurt_score.BleurtScorer(BLEURT_SCORE_REPO + "/test_checkpoint")
self.scorer = bleurt_score.BleurtScorer(
BLEURT_SCORE_REPO + "/bleurt/test_checkpoint"
)

def get_scores(self, score_me):
scores = self.scorer.score(
Expand Down Expand Up @@ -143,8 +145,6 @@ def get_scores(self, score_me):


######### BartScore (and its variants)
# Note: You might have to load some of its dependencies yourself.
# HuggingFace is the major one
# Second argument here should be path to `bart_score.py` of the BARTScore repo
try:
bart_score = SourceFileLoader(
Expand Down Expand Up @@ -190,8 +190,12 @@ def __init__(self):
self.scorer = BARTScorer(
device=DEFAULT_DEVICE, checkpoint='facebook/bart-large-cnn'
)
# Path here should be to fine tuned BART model from https://github.com/neulab/BARTScore#direct-use
self.scorer.load(BART_SCORE_REPO + "/bart_score_para_finetuned.pth")
try:
self.scorer.load(BART_SCORE_REPO + "/bart_score_para_finetuned.pth")
except FileNotFoundError:
raise FileNotFoundError(
f"Path here should be to fine tuned BART model from https://github.com/neulab/BARTScore#direct-use"
)
self.score_type = BARTSCORE_CNN_PARA_F


Expand All @@ -218,24 +222,25 @@ def load(self, path=None):
device=DEFAULT_DEVICE, checkpoint='facebook/bart-large-cnn'
)
# Path here to fine-tuend BART Model
self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth")
try:
self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth")
except FileNotFoundError:
raise FileNotFoundError(
f"Path here should be to fine tuned BART model from"
+ "https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth"
)
self.score_type = BARTSCORE_FINETUNED_F


######### Prism
# Prism deps (minimal set)
# sentencepiece>=0.1.86
# fairseq==0.9.0
# sacrebleu>=1.4.8#
# torch>=1.4.0
prism = SourceFileLoader("prism", PRISM_SCORE_REPO + "/prism.py").load_module()


@register_scorer([PRISM_AVG])
class PrismBaselineScorer(BaselineScorer):
def __init__(self):
self.scorer = prism.Prism(
model_dir=PRISM_SCORE_REPO + '/models/m39v1/',
model_dir=PRISM_SCORE_REPO + '/m39v1/',
lang='en',
)

Expand Down