feat: Enhance prompt formatting and optimize model performance #41

Merged · 1 commit on Nov 7, 2024
38 changes: 27 additions & 11 deletions src/04_fine_tuning.py
@@ -21,6 +21,7 @@

dotenv.load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
HF_TOKEN = os.getenv("HF_TOKEN")


def initialize_tokenizer(model_name: str, hf_token: str) -> AutoTokenizer:
@@ -156,8 +157,21 @@ def filter_quality(example: Dict[str, Any]) -> bool:
return special_char_ratio <= 0.2


def format_prompt(instruction: str, response: str = "") -> str:
Contributor


issue (complexity): Consider consolidating the duplicate format_prompt functions into a single implementation

The duplicate format_prompt functions with slightly different implementations introduce unnecessary complexity and potential for bugs. Consolidate them into a single function:

def format_prompt(instruction: str, response: str = "") -> str:
    """Format the prompt for the model with consistent system context and structure."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: 23 July 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{response}<|eot_id|>"""

This consolidation:

  • Uses consistent dates rather than mixing hardcoded and dynamic dates
  • Maintains a single prompt template structure
  • Removes the risk of diverging implementations
  • Preserves all functionality while reducing code duplication
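
For illustration only, here is a minimal sketch of how the two call sites touched in this diff (prepare_dataset for training text, inference_example for generation prompts) could share that single helper. The example strings are invented, and only the consolidated format_prompt above is assumed to be in scope:

# Assumes the consolidated format_prompt from the suggestion above is in scope.

# Training side (prepare_dataset): instruction and target response both filled in.
train_text = format_prompt(
    instruction="the new parking policy",
    response="Happy to help. The new parking policy works like this...",
)

# Inference side (inference_example): leave the response empty and let the
# model generate everything after the assistant header.
infer_text = format_prompt(instruction="the new parking policy")

print(train_text)
print(infer_text)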

"""Format the prompt for the model with emphasis on positive and helpful responses."""
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 23 July 2024

You are a helpful and polite assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Tell me about {instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{response}<|eot_id|>"""


def prepare_dataset(tokenizer):
"""Prepare dataset with Alpaca-style prompt template"""
dataset = load_dataset("leonvanbokhorst/synthetic-complaints-v2")

# Use full validation set and shuffle
@@ -181,15 +195,11 @@ def format_prompt(instruction: str, response: str = "") -> str:
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime('%d %b %Y')}

You are a helpful AI assistant.<|eot_id|>
Today Date: 23 July 2024

<|start_header_id|>user<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
Tell me about {instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{response}<|eot_id|>"""

@@ -250,6 +260,7 @@ def inference_example(model, tokenizer, prompt: str) -> str:
try:
device = model.device
formatted_prompt = format_prompt(prompt)
print(f"Formatted Prompt: {formatted_prompt}")

model_inputs = tokenizer(
formatted_prompt,
@@ -260,6 +271,8 @@
return_token_type_ids=False,
).to(device)

print(f"Tokenized Input IDs: {model_inputs['input_ids']}")

with torch.no_grad():
outputs = model.generate(
input_ids=model_inputs["input_ids"],
@@ -276,7 +289,10 @@
eos_token_id=tokenizer.eos_token_id,
)

print(f"Raw Output IDs: {outputs}")

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Decoded Response: {response}")

# Extract only the assistant's response
if "<|start_header_id|>assistant<|end_header_id|>" in response:
@@ -372,9 +388,9 @@ def get_training_args() -> TrainingArguments:
# WSL2 has better memory management than Windows, allowing for more aggressive batching
# RTX 4090 has 24GB VRAM and WSL2 can utilize it more efficiently
# Effective batch size = per_device_batch * gradient_accumulation = 8 * 4 = 32
per_device_train_batch_size=8, # Optimal for 24GB VRAM under WSL2
per_device_eval_batch_size=8, # Match training batch size
gradient_accumulation_steps=4, # Accumulate for larger effective batch
per_device_train_batch_size=12, # Optimal for 24GB VRAM under WSL2
per_device_eval_batch_size=12, # Match training batch size
gradient_accumulation_steps=6, # Accumulate for larger effective batch
# Data Loading Optimization
# WSL2's Linux kernel provides better process management than Windows
# Can use more CPU cores efficiently without system instability
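With the new values above, the effective batch size no longer matches the "8 * 4 = 32" in the unchanged comment: it works out to 12 * 6 = 72 samples per optimizer update. A minimal sketch of that arithmetic and the touched TrainingArguments follows; only the three values changed in the diff come from the PR, while output_dir and dataloader_num_workers are placeholders:

from transformers import TrainingArguments

# Values changed in this PR; everything else below is a placeholder.
per_device_train_batch_size = 12
gradient_accumulation_steps = 6

# One optimizer update now covers 12 * 6 = 72 samples per GPU
# (previously 8 * 4 = 32).
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(f"Effective batch size: {effective_batch_size}")  # -> 72

training_args = TrainingArguments(
    output_dir="./results",              # placeholder path
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=12,       # matches the training batch size
    gradient_accumulation_steps=gradient_accumulation_steps,
    dataloader_num_workers=4,            # placeholder; the comment above only says "more CPU cores"
)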
32 changes: 27 additions & 5 deletions src/11_llm_benchmark.py
@@ -65,15 +65,13 @@ def __init__(
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
token=HF_TOKEN,
torch_dtype=torch.float16,
).to(
self.device
) # Explicitly move to device

self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
token=HF_TOKEN,
)

# Load reference model for comparison
@@ -120,8 +118,22 @@ def generate_response(
- top_p=0.9: Nucleus sampling for natural language variation
- do_sample=True: Enables probabilistic sampling
"""
# Use the new prompt template
formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime('%d %b %Y')}

You are a helpful assistant <|eot_id|>

<|start_header_id|>user <|end_header_id|>

{prompt} <|eot_id|>
<|start_header_id|>assistant <|end_header_id|>
"""

inputs = tokenizer(
prompt,
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
@@ -273,10 +285,20 @@ def calculate_metrics(

return metrics

def clear_model_caches(self):
"""
Clear model caches to free up memory.
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()

def run_benchmark(self, num_samples: int = 100) -> Dict[str, Any]:
"""
Run comprehensive benchmark comparing models with detailed progress reporting.
"""
# Clear caches before starting the benchmark
self.clear_model_caches()

dataset = load_dataset("leonvanbokhorst/synthetic-complaints-v2")["train"]
test_samples = dataset.select(range(num_samples))

@@ -379,8 +401,8 @@ def run_benchmark(self, num_samples: int = 100) -> Dict[str, Any]:
# Print sample outputs
tqdm.write("\nSample Outputs:")
tqdm.write("Topic: " + sample["topic"])
tqdm.write("Fine-tuned (first 100 chars): " + ft_response[:100] + "...")
tqdm.write("Reference (first 100 chars): " + ref_response[:100] + "...")
tqdm.write("Fine-tuned (first 256 chars): " + ft_response[:256])
tqdm.write("Reference (first 256 chars): " + ref_response[:256])
tqdm.write("\n" + "=" * 50)

# Print warning if metrics are concerning
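The generate call that consumes formatted_prompt is truncated in the diff above, so here is only a hedged sketch of a decode step using the sampling settings documented in generate_response: top_p=0.9 and do_sample=True come from the docstring, while the model id, max_new_tokens, and temperature are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model id; the benchmark class actually receives model_name as a parameter.
model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id  # same workaround as in fine_tune_test.py
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16
)

# Stand-in for the Llama-style prompt built in generate_response above.
formatted_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "Tell me about my noisy neighbours.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

inputs = tokenizer(
    formatted_prompt, return_tensors="pt", padding=True, truncation=True
).to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,    # assumed
        do_sample=True,        # from the docstring
        top_p=0.9,             # from the docstring
        temperature=0.7,       # assumed
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))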
14 changes: 9 additions & 5 deletions src/poc/fine_tune_test.py
@@ -21,7 +21,7 @@
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the model with exact same settings as training
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
@@ -36,15 +36,19 @@
print("Model type:", type(model))

# Add this before the generate_complaint function
PROMPT_TEMPLATE = """### Instruction:
{0}
PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

### Response:
Cutting Knowledge Date: December 2023
Today Date: 23 July 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user <|end_header_id|>

{0}<|eot_id|><|start_header_id|>assistant <|end_header_id|>
"""

def generate_complaint(prompt: str) -> str:
# Make the instruction more explicit
instruction = f"Write an angry complaint about {prompt}. Be specific about why you are upset and express your frustration clearly."
instruction = f"Tell me about {prompt}."

formatted_prompt = PROMPT_TEMPLATE.format(instruction)

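The body of generate_complaint is cut off above, so as a sketch only: one way the new Llama-style template could be filled in and generated from. PROMPT_TEMPLATE, the model/tokenizer/device objects, and the "Tell me about {prompt}." instruction come from the file above; the generation arguments and the new-token slicing are assumptions rather than the file's actual code.

# Sketch only: `model`, `tokenizer`, `device`, `torch`, and PROMPT_TEMPLATE are the
# objects already set up earlier in fine_tune_test.py.
def generate_complaint_sketch(prompt: str) -> str:
    instruction = f"Tell me about {prompt}."
    formatted_prompt = PROMPT_TEMPLATE.format(instruction)

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=200,    # assumed
            do_sample=True,        # assumed
            temperature=0.8,       # assumed
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the tokens generated after the prompt, so the returned
    # string is just the assistant's reply.
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


print(generate_complaint_sketch("the office coffee machine"))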