Skip to content

Commit

Permalink
Improve formatting, printing
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Nov 15, 2023
1 parent d01bbb5 commit 20d84f9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions repepo/baselines/icl/eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import evaluate
import numpy as np
from termcolor import colored
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
Expand Down Expand Up @@ -52,7 +53,7 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, dat
# Get the dataset by name
examples: List[format.Example] = get_dataset(data_args.dataset_name)
# Format the dataset
completions: List[format.Completion] = format.InstructionFormatter().apply(examples)
completions: List[format.Completion] = format.QAFormatter().apply(examples)
completions = format.FewShotPrompter().apply(completions)

# Initialize dataset
Expand Down Expand Up @@ -217,10 +218,11 @@ def eval_model():
predictions = [pred[len(prompt):] for prompt, pred in zip(prompts, predictions)]
references = [tokenizer.decode(ref, skip_special_tokens=True) for ref in batch['reference_ids']]
# Visualize predictions
print("Sample prompt: ", prompts[0])
print("Sample prompt: ", colored(prompts[0], 'yellow'))
print()
print("Predicted answer: ", predictions[0])
print("Reference answer: ", references[0])
print("Predicted answer: ", colored(predictions[0], 'light_blue'))
print()
print("Reference answer: ", colored(references[0], 'green'))
print()
metrics = metric_fns.compute_metrics(predictions, references)
for name, metric in metrics.items():
Expand Down
2 changes: 1 addition & 1 deletion repepo/data/dataset/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def apply(self, completions: List[Completion]) -> List[Completion]:
selected_idxes.append(idx)

# Concatenate completions
prompt = '\n'.join(few_shot_examples) + completion['prompt']
prompt = '\n'.join(few_shot_examples) + '\n' + completion['prompt']
response = completion['response']
output_completions.append(dict(
prompt=prompt,
Expand Down

0 comments on commit 20d84f9

Please sign in to comment.