diff --git a/repepo/baselines/icl/eval.py b/repepo/baselines/icl/eval.py index 5155eebf..461d4b81 100644 --- a/repepo/baselines/icl/eval.py +++ b/repepo/baselines/icl/eval.py @@ -1,6 +1,7 @@ import torch import evaluate import numpy as np +from termcolor import colored from torch.utils.data import DataLoader, RandomSampler from transformers import AdamW, get_linear_schedule_with_warmup from tqdm import tqdm @@ -52,7 +53,7 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, dat # Get the dataset by name examples: List[format.Example] = get_dataset(data_args.dataset_name) # Format the dataset - completions: List[format.Completion] = format.InstructionFormatter().apply(examples) + completions: List[format.Completion] = format.QAFormatter().apply(examples) completions = format.FewShotPrompter().apply(completions) # Initialize dataset @@ -217,10 +218,11 @@ def eval_model(): predictions = [pred[len(prompt):] for prompt, pred in zip(prompts, predictions)] references = [tokenizer.decode(ref, skip_special_tokens=True) for ref in batch['reference_ids']] # Visualize predictions - print("Sample prompt: ", prompts[0]) + print("Sample prompt: ", colored(prompts[0], 'yellow')) print() - print("Predicted answer: ", predictions[0]) - print("Reference answer: ", references[0]) + print("Predicted answer: ", colored(predictions[0], 'light_blue')) + print() + print("Reference answer: ", colored(references[0], 'green')) print() metrics = metric_fns.compute_metrics(predictions, references) for name, metric in metrics.items(): diff --git a/repepo/data/dataset/format.py b/repepo/data/dataset/format.py index 5bc9291f..2ff87edd 100644 --- a/repepo/data/dataset/format.py +++ b/repepo/data/dataset/format.py @@ -94,7 +94,7 @@ def apply(self, completions: List[Completion]) -> List[Completion]: selected_idxes.append(idx) # Concatenate completions - prompt = '\n'.join(few_shot_examples) + completion['prompt'] + prompt = '\n'.join(few_shot_examples) + '\n' + completion['prompt'] response = completion['response'] output_completions.append(dict( prompt=prompt,