Commit a90cb32

Added a simple example
1 parent ef3f286 commit a90cb32

1 file changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time

import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
DISPLAYED_SAMPLES = 3


if __name__ == "__main__":
    # Parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-blocks", "-n", type=int, default=None)
    parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
    parser.add_argument(
        "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
    )
    parser.add_argument("--samples", type=int, default=500)
    args = parser.parse_args()
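    # (Assumption, not stated in this commit: the "implementation|repo" form of --attn selects
    # paged attention backed by an attention kernel fetched from the Hub repo named after the
    # "|", here kernels-community/flash-attn.)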

    # Prepare model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        attn_implementation=args.attn,
        dtype=torch.bfloat16,
    )
    model = model.cuda().eval()

    # Prepare tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
    dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
    dataset = dataset.select(range(args.samples))
    tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
    simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

    # Prepare generation config
    generation_config = GenerationConfig(
        max_new_tokens=512,
        use_cuda_graph=False,  # Not supported for simple version
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        num_blocks=args.num_blocks,
        max_batch_tokens=args.max_batch_tokens,
    )
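    # (Assumption: num_blocks sizes the paged-attention KV cache in fixed-size blocks, and
    # max_batch_tokens caps how many tokens are packed into a single forward pass; leaving
    # both as None defers to the continuous-batching backend's defaults.)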

    # Warmup iterations
    _ = model.generate_batch(
        inputs=simple_batch_inputs[: min(5, args.samples)],
        generation_config=generation_config,
        slice_inputs=True,
    )

    # Actual batch generation
    print("--- Running CB Generation Example ---")
    start_time = time.time()
    batch_outputs = model.generate_batch(
        inputs=simple_batch_inputs,
        generation_config=generation_config,
        slice_inputs=True,
    )
    end_time = time.time()
    print("Done with batch generation.")

    # Decode outputs
    token_count = 0
    for i, request in enumerate(batch_outputs):
        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
        # Try to decode the output
        try:
            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True)
            token_count += len(batch_outputs[request].generated_tokens[1:])
        except Exception as e:
            print(f"Decoding failed for request {request}: {e}")
            continue

        # Display the first DISPLAYED_SAMPLES samples
        if i < DISPLAYED_SAMPLES:
            print("-" * 20)
            print(f"{request} Input: {input_text}")
            if len(output_text) > 0:
                print(f"{request} Output: {output_text}")
            else:
                print(f"[WARN] {request} Output was empty!")

    # Compute and display throughput stats
    gen_time = end_time - start_time
    tok_per_sec = token_count / gen_time
    print("-" * 20)
    print("--- Finished CB Generation Example ---\n")
    print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. {tok_per_sec:.2f} tok/s")
