
Commit

Merge pull request #24 from probcomp/async-llm
Integrate AsyncLM models from genlm-backend
benlebrun authored Feb 19, 2025
2 parents 7c68968 + 969c22d commit b4a9824
Showing 13 changed files with 480 additions and 457 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -85,7 +85,7 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
18 changes: 13 additions & 5 deletions README.md
@@ -29,6 +29,14 @@ poetry run python examples/hard_constraints.py

If everything is working, you should see the model generate political news using words that are at most five letters long (e.g., "Dr. Jill Biden may still be a year away from the White House but she is set to make her first trip to the U.N. today.").

### vLLM backend

As of version 0.2.0, hfppl supports a vLLM backend, which provides significant speedups over the HuggingFace backend. To install this backend, run:

```
poetry install --with vllm
```
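
Once the extra is installed, the backend can be selected when loading a model. The following is a minimal sketch based on the benchmark script added in this PR (which passes `backend` and `engine_opts` to `CachedCausalLM.from_pretrained`); treat the specific options as illustrative rather than required:

```
from hfppl.llms import CachedCausalLM

# Load a model on the vLLM backend (requires CUDA).
# engine_opts is forwarded to the underlying vLLM engine; here it lowers GPU
# memory utilization, mirroring benchmark/benchmark_backend.py.
LLM = CachedCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    backend='vllm',
    engine_opts={'gpu_memory_utilization': 0.45},
)

# Pass backend='hf' to use the HuggingFace backend instead.
```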

## Modeling with LLaMPPL

A LLaMPPL program is a subclass of the `hfppl.Model` class.
@@ -47,18 +55,18 @@ class MyModel(Model):
        # A stateful context object for the LLM, initialized with the prompt
        self.context = LMContext(lm, prompt)
        self.eos_token = lm.tokenizer.eos_token_id

        # The forbidden letter
        self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
                                    if forbidden_letter in v)

    # The step method is used to perform a single 'step' of generation.
    # This might be a single token, a single phrase, or any other division.
    # Here, we generate one token at a time.
    async def step(self):
        # Condition on the next token *not* being a forbidden token.
        await self.observe(self.context.mask_dist(self.forbidden_tokens), False)

        # Sample the next token from the LLM -- automatically extends `self.context`.
        token = await self.sample(self.context.next_token())

@@ -86,7 +94,7 @@ import asyncio
from hfppl import smc_steer

# Initialize the HuggingFace model
lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=<YOUR_HUGGINGFACE_API_TOKEN_HERE>)
lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", backend='hf', auth_token=<YOUR_HUGGINGFACE_API_TOKEN_HERE>)

# Create a model instance
model = MyModel(lm, "The weather today is expected to be", "e")
@@ -105,4 +113,4 @@ sunny and cool.
hot and humid with a possibility of rain, which is not uncommon for this part of Mississippi.
```

Further documentation can be found at https://probcomp.github.io/hfppl.
Further documentation can be found at https://probcomp.github.io/hfppl.
63 changes: 63 additions & 0 deletions benchmark/benchmark_backend.py
@@ -0,0 +1,63 @@
"""
Requires pytest and pytest-benchmark (pip install pytest pytest-benchmark)
Example usage: pytest benchmark/benchmark_backend.py --benchmark-only --benchmark-group-by=func -v
"""

import torch
import pytest
import asyncio
from hfppl.llms import CachedCausalLM
from examples.haiku import run_example as run_haiku
from examples.hard_constraints import run_example as run_hard_constraints

backends = [
    'hf',
    pytest.param(
        'vllm',
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(),
            reason="vLLM backend requires CUDA"
        )
    )
]

@pytest.fixture
def LLM(backend):
    # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU
    kwargs = {'engine_opts' : {'gpu_memory_utilization' : 0.45}, 'cache_size' : 100} if backend == 'vllm' else {}
    return CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", backend=backend, **kwargs)

@pytest.mark.parametrize('backend', backends)
def test_hard_constraints_benchmark(LLM, benchmark, n_particles=20, max_tokens=50):
    def run_with_clear_cache():
        LLM.clear_cache()
        return asyncio.run(
            run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles)
        )

    # warmup
    run_with_clear_cache()

    benchmark.pedantic(
        run_with_clear_cache,
        iterations=1,
        rounds=3,
    )

@pytest.mark.parametrize('backend', backends)
def test_haiku_benchmark(LLM, benchmark, n_particles=20):
    def run_with_clear_cache():
        LLM.clear_cache()
        return asyncio.run(
            run_haiku(LLM, poem_title='The beauty of testing', n_particles=n_particles)
        )

    # warmup
    run_with_clear_cache()

    benchmark.pedantic(
        run_with_clear_cache,
        iterations=1,
        rounds=3,
    )
Empty file added examples/__init__.py
Empty file.
15 changes: 4 additions & 11 deletions examples/grammar_constraint.py
@@ -38,7 +38,7 @@ def __init__(
        self.lm = lm
        self.grammar = grammar
        self.context = LMContext(lm, prompt)
        self.vocab = self.lm.vocab
        self.vocab = self.lm.str_vocab
        self.eos_token_id = self.lm.tokenizer.eos_token_id

        self.comp_engine = LarkCompletionEngine(
@@ -125,10 +125,9 @@ async def run_generation(
    max_tokens: int = 32,
    verbose: bool = False,
):
    LLM = CachedCausalLM.from_pretrained(
        args.model, auth_token=os.getenv("HF_AUTH_TOKEN")
    )
    LLM.batch_size = args.batch_size
    LLM = CachedCausalLM.from_pretrained(args.model)
    if LLM.backend == 'hf':
        LLM.batch_size = args.batch_size
    model = GrammarConstrainedSMC(
        lm=LLM,
        grammar=grammar,
Expand Down Expand Up @@ -177,12 +176,6 @@ async def run_generation(
default=5,
help="Number of particles to use in SMC",
)
parser.add_argument(
"--batch-size",
type=int,
default=16,
help="LLM batch size",
)
parser.add_argument(
"--max-tokens",
type=int,
93 changes: 51 additions & 42 deletions examples/haiku.py
@@ -34,19 +34,6 @@ def count_syllables(word, unknown_word_syllables=100):

    return syllable_count


# Load the language model (llama2 if authorized, else mistral-7b).
if "HF_AUTH_TOKEN" in os.environ:
    HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
    LLM = CachedCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B", auth_token=HF_AUTH_TOKEN
    )
else:
    LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# Set batch size
LLM.batch_size = 40

# Example poems for the prompt.
# Authors:
# - Amy Lowell
@@ -56,7 +43,7 @@ def count_syllables(word, unknown_word_syllables=100):
# Note that not all of these follow the syllabic constraints of a Haiku; the goal is
# to encode a certain 'poetic style' but to leave the syllabic constraints to be enforced
# by the probabilistic program (enabling generalization to other syllabic constraints).
example_poems = """Example poems. Note how they tend to end on a somewhat surprising or otherwise satisfying note, and are not repetitive at the end.
EXAMPLE_POEMS = """Example poems. Note how they tend to end on a somewhat surprising or otherwise satisfying note, and are not repetitive at the end.
1. "Portrait"
Sweet smell of wet flowers
@@ -78,28 +65,16 @@ def count_syllables(word, unknown_word_syllables=100):
this deep in fall,
still not a butterfly."""

# Ask user for poem title (without newline)
poem_title = input("Enter a title for your Haiku: ")
poem_prompt = f"""{example_poems}
5. "{poem_title}"
"""

# Cache prompt for faster generation
LLM.cache_kv(LLM.tokenizer.encode(poem_prompt))

# Useful constants
NEWLINE_TOKEN, EOS_TOKEN = LLM.vocab.index("\n"), LLM.tokenizer.eos_token_id


# LLaMPPL model
class Haiku(Model):

    def __init__(self, prompt, syllable_pattern=[5, 7, 5]):
    def __init__(self, LLM, prompt, syllable_pattern=[5, 7, 5]):
        super().__init__()
        self.context = LMContext(LLM, prompt, 0.7)
        self.context = LMContext(LLM, prompt)
        self.syllable_pattern = syllable_pattern
        self.previous_string = str(self.context)
        self.newline_token = LLM.str_vocab.index("\n")
        self.eos_token = LLM.tokenizer.eos_token_id

    async def step(self):
        self.previous_string = str(self.context)
@@ -121,12 +96,12 @@ async def step(self):

        # If there are no more lines, finish
        if not self.syllable_pattern:
            await self.observe(self.context.next_token(), EOS_TOKEN)
            await self.observe(self.context.next_token(), self.eos_token)
            self.finish()
            return

        # Otherwise, observe a line break
        await self.observe(self.context.next_token(), NEWLINE_TOKEN)
        await self.observe(self.context.next_token(), self.newline_token)

        # Print current result
        print(str(self.context))
@@ -141,16 +116,50 @@ def string_for_serialization(self):
        )
        return s.replace("\n", "/")

async def run_example(LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5):
    # Construct prompt
    prompt = f"""{EXAMPLE_POEMS}
# Run inference
SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune
particles = asyncio.run(
    smc_standard(
        Haiku(poem_prompt, SYLLABLES_PER_LINE), 20, 0.5, "html", "results/haiku.json"
5. "{poem_title}"
"""

    # Cache the key value vectors for the prompt
    LLM.cache_kv(LLM.tokenizer.encode(prompt))

    # Initialize the Model
    haiku_model = Haiku(LLM, prompt, syllable_pattern)

    # Run inference
    particles = await smc_standard(
        haiku_model, n_particles, ess_threshold, "html", "results/haiku.json"
    )
)

# print("--------")
# for i, particle in enumerate(particles):
# print(f"Poem {i} (weight {particle.weight}):")
# print(f"{particle.context}")
    return particles

def main():
    # Load the language model.
    # Mistral is an open model; to use a model with restricted access, like LLaMA 3,
    # authenticate using the Huggingface CLI.
    LLM = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    # LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

    # Set batch size if using HuggingFace backend
    if LLM.backend == 'hf':
        LLM.batch_size = 40

    # Get poem title from user
    poem_title = input("Enter a title for your Haiku: ")

    syllables_per_line = [5, 7, 5] # [5, 3, 5] for a Lune

    # Run the example
    particles = asyncio.run(run_example(LLM, poem_title, syllable_pattern=syllables_per_line))

    print("--------")
    for i, particle in enumerate(particles):
        print(f"\nPoem {i} (weight {particle.weight}):")
        print(f"{particle.context}")

if __name__ == "__main__":
    main()
