Commit
feat: pipeline generation
dtch1997 committed May 5, 2024
1 parent 0f32c97 commit 9c69185
Showing 3 changed files with 63 additions and 0 deletions.
31 changes: 31 additions & 0 deletions repepo/core/pipeline.py
@@ -1,6 +1,7 @@
from contextlib import AbstractContextManager, ExitStack
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
from transformers.generation import GenerationConfig

import torch

@@ -117,6 +118,36 @@ def build_full_prompt(self, completion: Completion) -> str:
            self.formatter.format_conversation(completion, self.conversation_history)
        )

    def generate(
        self,
        completion: Completion,
        generation_config: GenerationConfig | None = None,
        remove_base_prompt: bool = True,
    ) -> str:
        """Generate a completion for a given example"""
        base_prompt = self.build_generation_prompt(completion)
        inputs: Any = self.tokenizer(base_prompt, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        context = PipelineContext(
            method="generate",
            base_prompt=base_prompt,
            full_prompt=base_prompt,
            inputs=inputs,
            pipeline=self,
        )
        # Enter every registered hook as a context manager, so hooks can
        # patch the model for the duration of the generate call.
        with ExitStack() as stack:
            for hook in self.hooks:
                stack.enter_context(hook(context))
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config,
            )[0]
            # decode() returns the prompt followed by the continuation;
            # optionally strip the echoed prompt so only new text is returned.
            outputs_str = self.tokenizer.decode(outputs, skip_special_tokens=True)
            if remove_base_prompt:
                return outputs_str.replace(base_prompt, "")
            return outputs_str
        # Unreachable at runtime; present only to satisfy the type checker,
        # which cannot see that the `with` block always returns.
        raise RuntimeError("Should never get here")

    @torch.no_grad()
    def calculate_output_logprobs(
        self, completion: Completion, slim_results: bool = False
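For orientation, here is a minimal sketch of how the new method could be called, assuming Pipeline and Completion are imported from repepo.core as in the tests below; the prompt text and max_new_tokens value are illustrative, not part of this commit:

    from transformers.generation import GenerationConfig

    pipeline = Pipeline(model, tokenizer)
    config = GenerationConfig(max_new_tokens=20, do_sample=False)
    # Returns only the newly generated text, since remove_base_prompt defaults to True.
    text = pipeline.generate(
        Completion(prompt="What is 2 + 2?", response=""),
        generation_config=config,
    )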
20 changes: 20 additions & 0 deletions repepo/notebooks/generative_metrics.ipynb
@@ -0,0 +1,20 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Test whether steering affects generative metric."
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
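The notebook is currently a stub. Given the hook protocol implied by Pipeline.generate above (each hook is a callable that takes a PipelineContext and returns a context manager), a first cell comparing steered and unsteered generations might look like the following hypothetical sketch; `my_steering_hook` and the mutability of `pipeline.hooks` are assumptions, not part of this commit:

    completion = Completion(prompt="Tell me about yourself.", response="")
    baseline = pipeline.generate(completion)

    pipeline.hooks.append(my_steering_hook)  # hypothetical steering hook
    steered = pipeline.generate(completion)

    print("baseline:", baseline)
    print("steered: ", steered)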
12 changes: 12 additions & 0 deletions tests/core/test_pipeline.py
@@ -66,6 +66,18 @@ def test_compute_quantiles_edge_cases():
    torch.testing.assert_allclose(output, expected_output)


def test_basic_Pipeline_generate(
    model: GPTNeoXForCausalLM, tokenizer: Tokenizer
) -> None:
    pipeline = Pipeline(model, tokenizer)
    res = pipeline.generate(
        Completion(prompt="Respond A B C D", response="E"),
        generation_config=None,
    )
    # pythia-70m generates nonsense, so just verify we get something
    assert isinstance(res, str)
    assert len(res) > 0

def test_basic_Pipeline_build_generation_prompt(
    model: GPTNeoXForCausalLM, tokenizer: Tokenizer
) -> None:
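A natural follow-up test would pin the generation length so the test stays fast; a sketch under the same fixtures, where the explicit GenerationConfig is the only new detail (assumes `from transformers.generation import GenerationConfig` at the top of the test file):

    def test_Pipeline_generate_with_max_new_tokens(
        model: GPTNeoXForCausalLM, tokenizer: Tokenizer
    ) -> None:
        pipeline = Pipeline(model, tokenizer)
        res = pipeline.generate(
            Completion(prompt="Respond A B C D", response="E"),
            # Cap generation so the test runs quickly regardless of model defaults.
            generation_config=GenerationConfig(max_new_tokens=5),
        )
        assert isinstance(res, str)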
