Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
updating requirements (#4860)
Browse files Browse the repository at this point in the history
* updating requirements

* lint
  • Loading branch information
Golovneva authored Nov 4, 2022
1 parent 5f41ba7 commit 49ea68c
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 125 deletions.
14 changes: 6 additions & 8 deletions projects/roscoe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,16 @@ bash projects/roscoe/roscoe_data/download_annotated.sh

### Baseline scoring
One-time setup:
It is higly recommended to run baseline scoring from a separate conda environment.
```bash
conda create --name roscoe_baselines python=3.8
conda activate roscoe_baselines
```
Follow BLEURT [installation quidelines](https://github.com/google-research/bleurt#installation).
Clone [BartScore repo](https://github.com/neulab/BARTScore) and update path in projects/roscoe/baselines/score.py. Install Bart_score requirements.
Upload fine-tuned [BART model](https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth).
Download PRISM [installation quidelines](https://github.com/thompsonb/prism) and download the model. Do not install requirements.
Install requirements to run baselines:
```bash
python -c "import nltk; nltk.download('punkt')"
python -c "import nltk; nltk.download('stopwords')"
pip install -r projects/roscoe/baselines/requirements.txt
pip install -r projects/roscoe/baselines/bart_requirements.txt
```
Follow BLEURT [installation quidelines](https://github.com/google-research/bleurt#installation)
Upload fine-tuned [BART model](https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth)

Then you can run baselines on all datasets, scores, and use of reference (when possible) with the following:
```bash
Expand Down
98 changes: 0 additions & 98 deletions projects/roscoe/baselines/bart_requirements.txt

This file was deleted.

6 changes: 2 additions & 4 deletions projects/roscoe/baselines/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
rouge-score
bert-score
ctc_score
rouge-score>=0.1.2
bert-score>=0.3.9
sentencepiece>=0.1.86
fairseq==0.9.0
sacrebleu>=1.4.8
torch>=1.4.0
ctc_score
2 changes: 1 addition & 1 deletion projects/roscoe/baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def save_scores_map(path, metric_to_score):
type=str,
choices=[x.value for x in UseRef],
nargs="+",
default=[x.value for x in UseRef],
default=UseRef.NO.value,
help='do we want to generate reference-based or reference-free scores',
)
parser.add_argument(
Expand Down
33 changes: 19 additions & 14 deletions projects/roscoe/baselines/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from importlib.machinery import SourceFileLoader

BART_SCORE_REPO = "/path_to/BARTScore/"
PRISM_SCORE_REPO = "/path_to/SUM"
BART_SCORE_REPO = "/path_to/BARTScore"
PRISM_SCORE_REPO = "/path_to/prism"
BLEURT_SCORE_REPO = "/path_to/bleurt"

######### Base functionality
Expand Down Expand Up @@ -114,7 +114,9 @@ def get_scores(self, score_me):
@register_scorer([BLEURT])
class BleurtBaselineScorer(BaselineScorer):
def __init__(self):
self.scorer = bleurt_score.BleurtScorer(BLEURT_SCORE_REPO + "/test_checkpoint")
self.scorer = bleurt_score.BleurtScorer(
BLEURT_SCORE_REPO + "/bleurt/test_checkpoint"
)

def get_scores(self, score_me):
scores = self.scorer.score(
Expand Down Expand Up @@ -143,8 +145,6 @@ def get_scores(self, score_me):


######### BartScore (and its variants)
# Note: You might have to load some of its dependencies yourself.
# HuggingFace is the major one
# Second argument here should be path to `bart_score.py` of the BARTScore repo
try:
bart_score = SourceFileLoader(
Expand Down Expand Up @@ -190,8 +190,12 @@ def __init__(self):
self.scorer = BARTScorer(
device=DEFAULT_DEVICE, checkpoint='facebook/bart-large-cnn'
)
# Path here should be to fine tuned BART model from https://github.com/neulab/BARTScore#direct-use
self.scorer.load(BART_SCORE_REPO + "/bart_score_para_finetuned.pth")
try:
self.scorer.load(BART_SCORE_REPO + "/bart_score_para_finetuned.pth")
except FileNotFoundError:
raise FileNotFoundError(
f"Path here should be to fine tuned BART model from https://github.com/neulab/BARTScore#direct-use"
)
self.score_type = BARTSCORE_CNN_PARA_F


Expand All @@ -218,24 +222,25 @@ def load(self, path=None):
device=DEFAULT_DEVICE, checkpoint='facebook/bart-large-cnn'
)
# Path here to fine-tuend BART Model
self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth")
try:
self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth")
except FileNotFoundError:
raise FileNotFoundError(
f"Path here should be to fine tuned BART model from"
+ "https://dl.fbaipublicfiles.com/parlai/projects/roscoe/fine_tuned_bartscore.pth"
)
self.score_type = BARTSCORE_FINETUNED_F


######### Prism
# Prism deps (minimal set)
# sentencepiece>=0.1.86
# fairseq==0.9.0
# sacrebleu>=1.4.8#
# torch>=1.4.0
prism = SourceFileLoader("prism", PRISM_SCORE_REPO + "/prism.py").load_module()


@register_scorer([PRISM_AVG])
class PrismBaselineScorer(BaselineScorer):
def __init__(self):
self.scorer = prism.Prism(
model_dir=PRISM_SCORE_REPO + '/models/m39v1/',
model_dir=PRISM_SCORE_REPO + '/m39v1/',
lang='en',
)

Expand Down

0 comments on commit 49ea68c

Please sign in to comment.