Commit

wip

dtch1997 committed May 15, 2024
1 parent 51d0a0c commit 8ae4ba9
Showing 3 changed files with 222 additions and 246 deletions.
9 changes: 4 additions & 5 deletions repepo/core/evaluate.py

@@ -255,7 +255,6 @@ def get_prediction(
     pipeline: Pipeline,
     example: Example,
     evaluators: Sequence[Evaluator] = [LogitDifferenceEvaluator()],
-    n_generation: int = 0,
     slim_results: bool = False,
 ) -> EvalPrediction:
     positive_probs = pipeline.calculate_output_logprobs(
@@ -272,10 +271,10 @@ def get_prediction(
     example_metrics = {}
     for evaluator in evaluators:
         example_metrics[evaluator.get_metric_name()] = evaluator.score_prediction(pred)
-    for _ in range(n_generation):
-        prompt = pipeline.build_generation_prompt(example.positive)
-        response = pipeline.generate(example.positive)
-        pred.generations.append(Completion(prompt=prompt, response=response))
+    # for _ in range(n_generation):
+    #     prompt = pipeline.build_generation_prompt(example.positive)
+    #     response = pipeline.generate(example.positive)
+    #     pred.generations.append(Completion(prompt=prompt, response=response))
     pred.metrics = example_metrics
     return pred
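After this change, get_prediction only computes evaluator metrics over logprobs and no longer samples generations. A hedged sketch of a caller (pipeline and example are placeholders, not defined in this diff):

    # get_prediction no longer accepts an n_generation argument
    pred = get_prediction(pipeline, example)
    print(pred.metrics)  # e.g. {"logit_difference": ...}, per evaluator
    # pred.generations is presumably left empty now; callers that need
    # sampled text must call pipeline.generate(...) themselves.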

27 changes: 26 additions & 1 deletion repepo/core/pipeline.py

@@ -59,7 +59,7 @@ def __repr__(self) -> str:

 @dataclass
 class PipelineContext:
-    method: Literal["generate", "logprobs"]
+    method: Literal["generate", "logprobs", "forward"]
     base_prompt: str
     full_prompt: str
     inputs: Any
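Each pipeline hook is entered as a context manager around the model call and receives this PipelineContext, so hooks can now branch on the new "forward" method. A hedged sketch (inspect_hook is hypothetical, not part of this diff):

    from contextlib import contextmanager

    @contextmanager
    def inspect_hook(context: PipelineContext):
        # runs before the model call; any cleanup goes after the yield
        if context.method == "forward":
            print("raw forward pass:", context.inputs["input_ids"].shape)
        yield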
@@ -118,6 +118,7 @@ def build_full_prompt(self, completion: Completion) -> str:
         self.formatter.format_conversation(completion, self.conversation_history)
     )

+    @torch.no_grad()
     def generate(
         self,
         completion: Completion,
@@ -147,12 +148,36 @@ def generate(
                 return outputs_str.replace(base_prompt, "")
             return outputs_str
         raise RuntimeError("Should never get here")

+    @torch.no_grad()
+    def __call__(self, text: str):
+        """A lightweight wrapper around model"""
+        # with ExitStack() as stack:
+        #     for hook in self.hooks:
+        #         stack.enter_context(hook(context))
+        # base_prompt = self.build_generation_prompt(completion)
+        # full_prompt = self.build_full_prompt(completion)
+        inputs: Any = self.tokenizer(text, return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        context = PipelineContext(
+            method="forward",
+            base_prompt="dummy",
+            full_prompt="dummy",
+            inputs=inputs,
+            pipeline=self,
+        )
+        with ExitStack() as stack:
+            for hook in self.hooks:
+                stack.enter_context(hook(context))
+            outputs = self.model(**inputs, output_hidden_states=False, return_dict=True)
+        return outputs
+
+    @torch.no_grad()
     def calculate_output_logprobs(
         self, completion: Completion, slim_results: bool = False
     ) -> TextProbs:
         """Calculate the logprobs for each token in the prompt + output"""

         base_prompt = self.build_generation_prompt(completion)
         full_prompt = self.build_full_prompt(completion)
         inputs: Any = self.tokenizer(full_prompt, return_tensors="pt")
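With this change, calling the pipeline directly runs a single forward pass with all registered hooks active. A hedged usage sketch, assuming self.model is a Hugging Face causal LM whose output exposes .logits (pipeline and the prompt text are placeholders):

    outputs = pipeline("The capital of France is")
    next_token_logits = outputs.logits[0, -1]  # logits for the next token
    next_token_id = next_token_logits.argmax().item()
    print(pipeline.tokenizer.decode(next_token_id))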
