feat/context consistency and tl #82
base: main
Conversation
adamimos
commented
Sep 30, 2025
- configs: interpolate vocab size from generator in penzai transformer/gruformer YAMLs
- run: inject vocab from generator; add BOS/EOS and seq_len consistency checks; route PyTorch models to torch trainer with n_ctx alignment (see the sketch after this list)
- torch: move batch to model device; ensure token dtype long in training and evaluation
- configs: add TransformerLens example (predictive_model, training, evaluation) and mess3_085 generator; include convenience top-level TL config
- docs: add TransformerLens example usage and notes
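For illustration, a minimal sketch of what the n_ctx / seq_len alignment check could look like; the attribute paths (`cfg.predictive_model.instance.cfg.n_ctx`, `cfg.training.sequence_len`) are assumptions for this sketch, not necessarily the PR's actual code:

```python
# Hypothetical consistency check; attribute paths are assumptions.
n_ctx = cfg.predictive_model.instance.cfg.n_ctx
if n_ctx != cfg.training.sequence_len:
    raise ValueError(
        f"model n_ctx ({n_ctx}) must match training.sequence_len "
        f"({cfg.training.sequence_len}) for consistent context"
    )
```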
Code Review for PR #82: feat/context consistency and tl

Thank you for this PR! The implementation of TransformerLens support and context consistency checks adds valuable functionality. Here's my detailed review:

✅ Strengths
🔍 Suggestions for Improvement
@claude codex did this PR and it's ugly. You are so much better at it. Please do your work.
Claude finished @adamimos's task

Code Review for PR #82: feat/context consistency and tl

I'll provide a comprehensive analysis of this PR's implementation of TransformerLens support and context consistency improvements.
🎯 Overall Assessment

This PR successfully integrates TransformerLens support into the codebase while adding robust consistency checks. The implementation is thoughtful and maintains backward compatibility while introducing new capabilities.

🏗️ Architecture & Design

Strengths
Areas for Improvement

1. Exception Handling Specificity

Issue: The broad exception catching reduces debuggability and masks potential issues.

```python
# Current (run.py:97-98)
try:
    import torch.nn as nn
except Exception:
    nn = None
```

Recommendation: Be specific about expected failures:

```python
try:
    import torch.nn as nn
except ImportError:
    nn = None  # PyTorch not installed, use JAX path
```

This pattern appears in multiple places (lines 54-56, 111-113) and should be consistently improved.

2. Configuration Validation Architecture

Issue: The vocab size setting logic (lines 41-56) is complex, with nested conditionals that are hard to test.

Recommendation: Extract it to a dedicated function:

```python
def inject_vocab_size(cfg: DictConfig, vocab_size: int) -> None:
"""Inject vocab size into model config, handling various config structures."""
inst = cfg.predictive_model.instance
# Direct attribute
if hasattr(inst, "vocab_size") or (isinstance(inst, DictConfig) and "vocab_size" in inst):
inst.vocab_size = vocab_size
# Penzai-style
if isinstance(inst, DictConfig) and "config" in inst and "vocab_size" in inst["config"]:
inst.config.vocab_size = vocab_size
# TransformerLens-style
if isinstance(inst, DictConfig) and "cfg" in inst and "d_vocab" in inst["cfg"]:
        inst.cfg.d_vocab = vocab_size
```
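As a usage note, a hypothetical call site could look like the following; `generator.vocab_size` and the surrounding flow are assumptions for illustration, not the PR's actual code:

```python
# Hypothetical call site: read the vocab size from the generator, inject it
# into the model config, then instantiate the model.
vocab_size = generator.vocab_size
inject_vocab_size(cfg, vocab_size)
model = typed_instantiate(cfg.predictive_model.instance, object)
```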
3. Device Handling Robustness

Issue: Device detection in evaluation (evaluate_pytorch_model.py:28-31) could fail on models without parameters.

```python
try:
    device = next(model.parameters()).device
except Exception:
    device = None
```

Recommendation: Handle StopIteration explicitly:

```python
try:
    device = next(model.parameters()).device
except StopIteration:
    # Model has no parameters, default to CPU
    device = torch.device('cpu')
```

🔍 Detailed Code Analysis

run.py Changes

Lines 41-56: The vocab size injection is well thought out, but it could benefit from logging when injection succeeds or fails, to aid debugging.

Line 42: Uses `model = typed_instantiate(cfg.predictive_model.instance, object)` with `object` as the base type.

Lines 65-74: Good defensive programming with assertions, but consider raising explicit exceptions with descriptive messages instead:

```python
if cfg.validation is not None and cfg.validation.sequence_len != cfg.training.sequence_len:
    raise ValueError(
        f"validation.sequence_len ({cfg.validation.sequence_len}) must match "
        f"training.sequence_len ({cfg.training.sequence_len}) for consistent context"
    )
```

PyTorch Training/Evaluation

Good Practices:
Potential Improvement: In both files, the device/dtype handling pattern is duplicated. Consider a utility function:

```python
def prepare_batch_for_model(model: torch.nn.Module, inputs: torch.Tensor, labels: torch.Tensor):
    """Prepare batch tensors for model processing."""
    # Fall back to CPU when the model has no parameters (avoids StopIteration).
    params = list(model.parameters())
    device = params[0].device if params else torch.device('cpu')
    inputs = inputs.to(device).long()
    labels = labels.to(device).long()
    return inputs, labels
```
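For illustration, a sketch of how such a helper might be called in a shared training loop; the loop structure, loss, and names below are assumptions, not the PR's code:

```python
import torch

def train_one_epoch(model, dataloader, optimizer):
    # Placeholder loop: dataloader is assumed to yield (inputs, labels) token batches.
    for inputs, labels in dataloader:
        inputs, labels = prepare_batch_for_model(model, inputs, labels)
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```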
Configuration Files

transformer_lens_2L2H.yaml:

transformerlens.yaml:
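For reference, a minimal sketch of the kind of model a 2-layer, 2-head TransformerLens config describes; the dimensions below are illustrative assumptions, not the values in transformer_lens_2L2H.yaml:

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Illustrative values only; the real ones live in the YAML config. d_vocab is
# injected from the generator at run time, and n_ctx should match
# training.sequence_len per the consistency checks above.
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=2,
    d_model=64,
    d_head=32,
    n_ctx=16,
    d_vocab=3,
    act_fn="relu",
)
model = HookedTransformer(cfg)
```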