Merge pull request #41 from leonvanbokhorst/model-bench
feat: Enhance prompt formatting and optimize model performance
leonvanbokhorst authored Nov 7, 2024
2 parents 8b8d1e9 + 208b1e1 commit 4dc1cfe
Showing 3 changed files with 63 additions and 21 deletions.
38 changes: 27 additions & 11 deletions src/04_fine_tuning.py
@@ -21,6 +21,7 @@

dotenv.load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ HF_TOKEN = os.getenv("HF_TOKEN")


def initialize_tokenizer(model_name: str, hf_token: str) -> AutoTokenizer:
@@ -156,8 +157,21 @@ def filter_quality(example: Dict[str, Any]) -> bool:
return special_char_ratio <= 0.2


+ def format_prompt(instruction: str, response: str = "") -> str:
+ """Format the prompt for the model with emphasis on positive and helpful responses."""
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+ Cutting Knowledge Date: December 2023
+ Today Date: 23 July 2024
+ You are a helpful and polite assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+ Tell me about {instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ {response}<|eot_id|>"""


def prepare_dataset(tokenizer):
"""Prepare dataset with Alpaca-style prompt template"""
dataset = load_dataset("leonvanbokhorst/synthetic-complaints-v2")

# Use full validation set and shuffle
@@ -181,15 +195,11 @@ def format_prompt(instruction: str, response: str = "") -> str:
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
- Today Date: {datetime.now().strftime('%d %b %Y')}
- You are a helpful AI assistant.<|eot_id|>
+ Today Date: 23 July 2024
- <|start_header_id|>user<|end_header_id|>
+ You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
- {instruction}<|eot_id|>
- <|start_header_id|>assistant<|end_header_id|>
+ Tell me about {instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{response}<|eot_id|>"""

@@ -250,6 +260,7 @@ def inference_example(model, tokenizer, prompt: str) -> str:
try:
device = model.device
formatted_prompt = format_prompt(prompt)
+ print(f"Formatted Prompt: {formatted_prompt}")

model_inputs = tokenizer(
formatted_prompt,
@@ -260,6 +271,8 @@
return_token_type_ids=False,
).to(device)

+ print(f"Tokenized Input IDs: {model_inputs['input_ids']}")

with torch.no_grad():
outputs = model.generate(
input_ids=model_inputs["input_ids"],
@@ -276,7 +289,10 @@
eos_token_id=tokenizer.eos_token_id,
)

+ print(f"Raw Output IDs: {outputs}")

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ print(f"Decoded Response: {response}")

# Extract only the assistant's response
if "<|start_header_id|>assistant<|end_header_id|>" in response:
@@ -372,9 +388,9 @@ def get_training_args() -> TrainingArguments:
# WSL2 has better memory management than Windows, allowing for more aggressive batching
# RTX 4090 has 24GB VRAM and WSL2 can utilize it more efficiently
# Effective batch size = per_device_batch * gradient_accumulation = 8 * 4 = 32
- per_device_train_batch_size=8, # Optimal for 24GB VRAM under WSL2
- per_device_eval_batch_size=8, # Match training batch size
- gradient_accumulation_steps=4, # Accumulate for larger effective batch
+ per_device_train_batch_size=12, # Optimal for 24GB VRAM under WSL2
+ per_device_eval_batch_size=12, # Match training batch size
+ gradient_accumulation_steps=6, # Accumulate for larger effective batch
# Data Loading Optimization
# WSL2's Linux kernel provides better process management than Windows
# Can use more CPU cores efficiently without system instability
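For quick reference, a self-contained sketch (illustrative only, not part of the commit) of what the new format_prompt template in src/04_fine_tuning.py renders, together with the effective-batch-size arithmetic implied by the updated TrainingArguments values:

# Illustrative sketch: mirrors the format_prompt added above and works out
# the effective batch size under the new settings.

def format_prompt(instruction: str, response: str = "") -> str:
    """Render the Llama-3-style chat template used for fine-tuning."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: 23 July 2024
You are a helpful and polite assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Tell me about {instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{response}<|eot_id|>"""

print(format_prompt("noisy neighbours"))

# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
print(12 * 6)  # 72 with the new values (previously 8 * 4 = 32)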
32 changes: 27 additions & 5 deletions src/11_llm_benchmark.py
@@ -65,15 +65,13 @@ def __init__(
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
- token=HF_TOKEN,
torch_dtype=torch.float16,
).to(
self.device
) # Explicitly move to device

self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
- token=HF_TOKEN,
)

# Load reference model for comparison
@@ -120,8 +118,22 @@ def generate_response(
- top_p=0.9: Nucleus sampling for natural language variation
- do_sample=True: Enables probabilistic sampling
"""
+ # Use the new prompt template
+ formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+ Cutting Knowledge Date: December 2023
+ Today Date: {datetime.now().strftime('%d %b %Y')}
+ You are a helpful assistant <|eot_id|>
+ <|start_header_id|>user <|end_header_id|>
+ {prompt} <|eot_id|>
+ <|start_header_id|>assistant <|end_header_id|>
+ """

inputs = tokenizer(
- prompt,
+ formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
@@ -273,10 +285,20 @@ def calculate_metrics(

return metrics

+ def clear_model_caches(self):
+ """
+ Clear model caches to free up memory.
+ """
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()

def run_benchmark(self, num_samples: int = 100) -> Dict[str, Any]:
"""
Run comprehensive benchmark comparing models with detailed progress reporting.
"""
+ # Clear caches before starting the benchmark
+ self.clear_model_caches()

dataset = load_dataset("leonvanbokhorst/synthetic-complaints-v2")["train"]
test_samples = dataset.select(range(num_samples))

@@ -379,8 +401,8 @@ def run_benchmark(self, num_samples: int = 100) -> Dict[str, Any]:
# Print sample outputs
tqdm.write("\nSample Outputs:")
tqdm.write("Topic: " + sample["topic"])
- tqdm.write("Fine-tuned (first 100 chars): " + ft_response[:100] + "...")
- tqdm.write("Reference (first 100 chars): " + ref_response[:100] + "...")
+ tqdm.write("Fine-tuned (first 256 chars): " + ft_response[:256])
+ tqdm.write("Reference (first 256 chars): " + ref_response[:256])
tqdm.write("\n" + "=" * 50)

# Print warning if metrics are concerning
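A minimal sketch (illustrative only; the standalone helper names build_benchmark_prompt and clear_model_caches are assumptions, not the repository's API) of the two patterns the benchmark changes in src/11_llm_benchmark.py rely on: wrapping the raw topic in the chat template before tokenization, and emptying the CUDA cache before a run:

from datetime import datetime

import torch


def build_benchmark_prompt(prompt: str) -> str:
    # Format the prompt with the chat template, as generate_response now does.
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime('%d %b %Y')}
You are a helpful assistant <|eot_id|>
<|start_header_id|>user <|end_header_id|>
{prompt} <|eot_id|>
<|start_header_id|>assistant <|end_header_id|>
"""


def clear_model_caches() -> None:
    # Release cached CUDA memory between benchmark runs, if a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


clear_model_caches()
print(build_benchmark_prompt("rising rent prices")[:256])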
14 changes: 9 additions & 5 deletions src/poc/fine_tune_test.py
@@ -21,7 +21,7 @@
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the model with exact same settings as training
- device = "mps" if torch.backends.mps.is_available() else "cpu"
+ device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
@@ -36,15 +36,19 @@
print("Model type:", type(model))

# Add this before the generate_complaint function
- PROMPT_TEMPLATE = """### Instruction:
- {0}
+ PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
- ### Response:
+ Cutting Knowledge Date: December 2023
+ Today Date: 23 July 2024
+ You are a helpful assistant<|eot_id|><|start_header_id|>user <|end_header_id|>
+ {0}<|eot_id|><|start_header_id|>assistant <|end_header_id|>
"""

def generate_complaint(prompt: str) -> str:
# Make the instruction more explicit
- instruction = f"Write an angry complaint about {prompt}. Be specific about why you are upset and express your frustration clearly."
+ instruction = f"Tell me about {prompt}."

formatted_prompt = PROMPT_TEMPLATE.format(instruction)

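And a short sketch (illustrative only; model loading and generation from src/poc/fine_tune_test.py are omitted) of the device-selection fallback and the new PROMPT_TEMPLATE usage in the test script:

import torch

# Prefer Apple's MPS backend, then CUDA, then CPU -- the new fallback order.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
print(f"Using device: {device}")

PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: 23 July 2024
You are a helpful assistant<|eot_id|><|start_header_id|>user <|end_header_id|>
{0}<|eot_id|><|start_header_id|>assistant <|end_header_id|>
"""

print(PROMPT_TEMPLATE.format("Tell me about slow wifi."))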
