Skip to content

Commit

Permalink
feat(trainer): Enhance SIFT trainer initialization and metrics handling
Browse files Browse the repository at this point in the history
- Add proper device/dtype handling for Mac M3 compatibility
- Initialize optimizer (AdamW) and scheduler (OneCycleLR)
- Add robust checkpoint saving with optimizer state
- Implement EMA loss tracking and visualization
- Add comprehensive metrics computation and tracking
- Fix tensor dtype handling for embedding indices
- Add adaptive stopping criterion with uncertainty tracking

Key changes:
- SIFTTrainer initialization (lines 65-158)
- Metrics computation and visualization (lines 583-620)
- Tensor dtype handling (lines 646-655)
- Checkpoint management (lines 552-573)

This commit improves training stability and adds better monitoring
capabilities while ensuring compatibility across different hardware
configurations.
  • Loading branch information
leonvanbokhorst committed Nov 17, 2024
1 parent 6389f0d commit 3411341
Show file tree
Hide file tree
Showing 9 changed files with 573 additions and 277 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ data/
Llama-3.2-1B-Instruct-Complaint/
dataset/
cache/
checkpoints/



Expand Down
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ The codebase explores several key areas of model adaptation:
- Graceful shutdown
- Comprehensive logging


13. **Semantic Router** (`15_semantic_router.py`)
- Semantic-based query routing system
- Core components:
Expand Down Expand Up @@ -333,7 +332,7 @@ The codebase explores several key areas of model adaptation:
- Motion data types:
- Orientation
- Position
- Velocity
- Velocity
- Acceleration
- Angular velocity
- Features:
Expand All @@ -352,7 +351,6 @@ The codebase explores several key areas of model adaptation:
- Missing frame handling
- Exception tracking


**PyTorch Experiments** (`src/poc/`)
- Neural Network Architecture Studies
- ResNet Implementation (`resnet.py`, `resnet_02.py`, `resnet_03.py`)
Expand Down Expand Up @@ -391,7 +389,7 @@ The codebase explores several key areas of model adaptation:
- Early stopping
- Learning rate scheduling

** 19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
**19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
- SIFT (Select Informative data for Fine-Tuning) Implementation
- Real-time model adaptation
- Semantic similarity search
Expand Down Expand Up @@ -431,7 +429,6 @@ The codebase explores several key areas of model adaptation:
- Real-time visualization
- Training summaries


## Development Standards

- PEP 8 and black formatting
Expand Down
107 changes: 89 additions & 18 deletions src/19_llm_fine_tuning_sift.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from sift.sift_trainer import SIFTTrainer
from sift.sift_visualization import SIFTVisualizer
import sys
import numpy as np

# Configure logging
logging.basicConfig(level=logging.ERROR, format="%(message)s")
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -49,12 +50,11 @@ def main():
metrics = MetricsComputer()
visualizer = SIFTVisualizer()

# Initialize trainer
# Initialize trainer with enhanced parameters
trainer = SIFTTrainer(
llm_name="unsloth/Llama-3.2-1B",
embedding_model="BAAI/bge-large-en-v1.5",
index_dir="cache/faiss",
max_length=512,
)

# Sample test prompts and training data
Expand Down Expand Up @@ -139,20 +139,31 @@ def format_step(step: int, metrics: dict) -> str:
# Training loop
last_prompt_stats = None

# Aggregate metrics (losses, uncertainties, steps per prompt) across all prompts
metrics_tracker = {
'global_losses': [],
'prompt_losses': [],
'uncertainties': [],
'steps_per_prompt': []
}

for prompt_idx, prompt in enumerate(test_prompts):
try:
# Quiet the trainer and tqdm loggers (WARNING and above only) during example selection
logging.getLogger("sift.sift_trainer").setLevel(logging.WARNING)
logging.getLogger("tqdm").setLevel(logging.WARNING)

selected_examples = trainer.select_examples(prompt, training_data)
# Use enhanced selection method
selected_examples = trainer.select_examples_sift(prompt, training_data)
logger.info(f"Selected {len(selected_examples)} examples for fine-tuning")
if not selected_examples:
logger.warning("No examples selected - skipping prompt")
continue

# Re-enable logging and reprint header
logging.getLogger("sift.sift_trainer").setLevel(logging.INFO)

clear_screen()
#clear_screen()

# Reprint header after clear
logger.info(format_last_prompt_summary(last_prompt_stats))
Expand All @@ -166,36 +177,57 @@ def format_step(step: int, metrics: dict) -> str:
if step_metrics is None:
continue

current_loss = min(
step_metrics.get("loss", float("inf")), max_loss_threshold
)
prev_loss = (
prompt_stats["losses"][-1] if prompt_stats["losses"] else None
)

# Compute kernel-based uncertainty with stability check
uncertainties = []
for _ in range(3): # Multiple measurements for stability
uncertainty = trainer.compute_kernel_uncertainty(
prompt, selected_examples[:i + 1]
)
if uncertainty is not None and not np.isnan(uncertainty):
uncertainties.append(uncertainty)

if not uncertainties:
continue

uncertainty = np.median(uncertainties) # Use median for robustness
current_loss = min(step_metrics.get("loss", float("inf")), max_loss_threshold)

# Update tracking
metrics_tracker['global_losses'].append(current_loss)
metrics_tracker['uncertainties'].append(uncertainty)

prompt_stats["losses"].append(current_loss)
prompt_stats["prompt_best"] = min(
prompt_stats["prompt_best"], current_loss
)
prompt_stats["prompt_best"] = min(prompt_stats["prompt_best"], current_loss)

if current_loss < global_best_loss:
global_best_loss = current_loss
# Save best model checkpoint
trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}")

# Only log every few steps
# Log progress
if i % 1 == 0:
logger.info(
format_step(
i,
{
"loss": current_loss,
"prev_loss": prev_loss,
"prev_loss": prompt_stats["losses"][-2] if len(prompt_stats["losses"]) > 1 else None,
"global_best": global_best_loss,
"uncertainty": step_metrics.get("uncertainty", 0),
"uncertainty": uncertainty,
},
)
)

# Enhanced stopping check with stability verification
if (i >= min_examples and
trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and
len(prompt_stats["losses"]) >= 3 and
np.std(prompt_stats["losses"][-3:]) < 0.1):
logger.info(f"Stopping early at step {i} due to convergence")
break

except Exception as e:
logger.error(f"Error in training step: {str(e)}")
continue

# Store and log summary
Expand Down Expand Up @@ -241,6 +273,45 @@ def format_step(step: int, metrics: dict) -> str:

logger.info("=" * 89)

# Visualize the metrics collected during the training loop
if prompt_stats_history:
metrics_data = {
"loss": [stat["losses"] for stat in prompt_stats_history],
"uncertainty": [
stat.get("uncertainties", []) for stat in prompt_stats_history
],
}

visualizer.plot_metrics_over_time(metrics_data)
visualizer.plot_uncertainty_vs_performance(
uncertainty=[
stat.get("uncertainties", [])[-1]
for stat in prompt_stats_history
if stat.get("uncertainties")
],
performance=[stat["prompt_best"] for stat in prompt_stats_history],
save_path="uncertainty_vs_performance.png",
)
visualizer.plot_adaptive_stopping(
metrics={
"uncertainty": [
u
for stat in prompt_stats_history
for u in stat.get("uncertainties", [])
],
"compute": list(
range(
sum(
len(stat.get("uncertainties", []))
for stat in prompt_stats_history
)
)
),
},
alpha=0.1,
save_path="adaptive_stopping.png",
)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/sift/sift_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"Please install it with: pip install zstandard"
)

logging.basicConfig(level=logging.ERROR)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


Expand Down
2 changes: 1 addition & 1 deletion src/sift/sift_embedding_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Tuple

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
logger.setLevel(logging.INFO)


def clear_embedding_cache(cache_dir: str = "cache/embeddings"):
Expand Down
73 changes: 0 additions & 73 deletions src/sift/sift_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,76 +11,3 @@ class EvaluationMetrics:
bits_per_byte: float
perplexity: float
uncertainty: float


class MetricsComputer:
    """Compute and track evaluation metrics.

    The loss-derived metrics (bits and perplexity) are both computed from the
    mean token-level cross-entropy between ``logits`` and ``labels``.
    """

    def __init__(self):
        # Recorded metric objects. Nothing in this class appends to the list;
        # get_metrics_summary() reads `.bits_per_byte`, `.perplexity` and
        # `.uncertainty` from each entry, so callers are presumably expected
        # to append EvaluationMetrics-like objects themselves (TODO confirm).
        self.metrics_history = []

    @staticmethod
    def _mean_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy (in nats) over all flattened positions."""
        # reshape() instead of view(): sliced (e.g. shifted) logits/labels are
        # non-contiguous, and view() raises on those.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        return torch.nn.functional.cross_entropy(
            flat_logits, flat_labels, reduction="mean"
        )

    def compute_bits_per_byte(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> float:
        """Compute bits per byte metric.

        Converts the mean cross-entropy from nats to bits (division by ln 2).
        NOTE(review): the average is taken per label position, not per byte of
        text; no byte-length normalisation happens here.
        """
        loss = self._mean_cross_entropy(logits, labels)
        return (loss / np.log(2)).item()

    def compute_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> float:
        """Compute perplexity as ``exp`` of the mean cross-entropy."""
        loss = self._mean_cross_entropy(logits, labels)
        return torch.exp(loss).item()

    def compute_metrics(
        self,
        outputs: Dict[str, Any],
        labels: Optional[torch.Tensor] = None,
        uncertainty: Optional[float] = None
    ) -> Dict[str, float]:
        """Compute training metrics.

        Args:
            outputs: Model outputs; must contain ``'loss'``. A ``'logits'``
                entry is optional and currently unused.
            labels: Target labels. Accepted for interface compatibility; no
                logits-based metrics are derived from them yet.
            uncertainty: Uncertainty estimate for this step, if available.

        Returns:
            Dict with ``'loss'`` (passed through unchanged) and
            ``'uncertainty'`` (``0.0`` when none was supplied).

        Raises:
            KeyError: If ``outputs`` has no ``'loss'`` entry.
        """
        # The previous implementation indexed outputs['logits'] only to reach
        # an empty branch, raising KeyError for outputs without logits. The
        # lookup is dropped because nothing is computed from it.
        return {
            'loss': outputs['loss'],
            'uncertainty': uncertainty if uncertainty is not None else 0.0,
        }

    def get_metrics_summary(self) -> Dict[str, List[float]]:
        """Get summary of tracked metrics as per-metric lists, in history order."""
        return {
            "bits_per_byte": [m.bits_per_byte for m in self.metrics_history],
            "perplexity": [m.perplexity for m in self.metrics_history],
            "uncertainty": [m.uncertainty for m in self.metrics_history],
        }


class AdaptiveStoppingMetrics:
    """Track metrics for adaptive stopping.

    Each call to should_stop() records the observation, so that
    get_stopping_summary() reflects every step that was evaluated.
    """

    def __init__(self, alpha: float = 0.1, min_steps: int = 5):
        """Set up the tracker.

        Args:
            alpha: Uncertainty threshold; stopping is signalled once the
                uncertainty falls strictly below it.
            min_steps: Step index before which stopping is never signalled
                (defaults to the previously hard-coded value of 5).
        """
        self.alpha = alpha
        self.min_steps = min_steps
        self.uncertainty_history = []
        self.compute_history = []

    def should_stop(self, uncertainty: float, step: int) -> bool:
        """Determine if training should stop based on uncertainty.

        Args:
            uncertainty: Current uncertainty estimate.
            step: Current step index.

        Returns:
            False while ``step < min_steps``; otherwise whether
            ``uncertainty < alpha``.
        """
        # Record every observation. Previously nothing was appended, so the
        # summary was always empty.
        # NOTE(review): the step index is used as the "compute" measure;
        # confirm that matches what consumers of the summary expect.
        self.uncertainty_history.append(uncertainty)
        self.compute_history.append(step)

        if step < self.min_steps:  # Minimum number of steps
            return False

        return uncertainty < self.alpha

    def get_stopping_summary(self) -> Dict[str, list]:
        """Get summary of stopping metrics (the live history lists, not copies)."""
        return {
            "uncertainty": self.uncertainty_history,
            "compute": self.compute_history,
        }
23 changes: 16 additions & 7 deletions src/sift/sift_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class MetricsComputer:
def __init__(self):
Expand Down Expand Up @@ -94,24 +95,32 @@ def __init__(self, alpha: float = 0.1, window_size: int = 5):
self.stopping_points = []
self.uncertainty_history = []

def should_stop(self, uncertainty: float, step: int) -> bool:
"""Determine if training should stop based on uncertainty."""
def should_stop(self, uncertainty: float, step: int, loss_history: List[float]) -> bool:
"""Enhanced stopping criterion using both uncertainty and loss trends."""
try:
self.uncertainty_history.append(uncertainty)

if step < self.window_size:
return False

# Get recent uncertainties
# Get recent metrics
recent_uncertainties = self.uncertainty_history[-self.window_size:]
avg_uncertainty = sum(recent_uncertainties) / len(recent_uncertainties)
recent_losses = loss_history[-self.window_size:]

# Stop if average uncertainty is below threshold
should_stop = avg_uncertainty < self.alpha
# Compute trends
uncertainty_trend = np.polyfit(range(len(recent_uncertainties)), recent_uncertainties, 1)[0]
loss_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]

# Stop if both metrics are stable or improving
should_stop = (
uncertainty_trend <= 0 and # Uncertainty is decreasing
loss_trend <= 0 and # Loss is decreasing
np.mean(recent_uncertainties) < self.alpha
)

if should_stop:
self.stopping_points.append(step)
logger.info(f"Stopping at step {step} with uncertainty {avg_uncertainty:.4f}")
logger.info(f"Stopping at step {step} with uncertainty {np.mean(recent_uncertainties):.4f}")

return should_stop

Expand Down
Loading

0 comments on commit 3411341

Please sign in to comment.