[WIP] config for EAGLE data generation #168
This PR adds a comprehensive type-safe configuration generator for the EAGLE data generation pipeline, replacing the manual dictionary-based configuration with a robust dataclass architecture.
## Key Changes

### Type-Safe Configuration Architecture
- Uses the `@dataclass` decorator for type safety

### Configuration Components
The new system includes well-defined dataclasses for each configuration section:
- `PackageVersions`: Tracks torch, vllm, transformers, and speculators versions
- `ReproducibilityInfo`: Command, cache key, GPU info, and package versions
- `ModelConfig`: Target model path, tensor parallel size, max model len, GPU memory utilization, hidden size
- `DataConfig`: Training data path, chat template, sequence length, max samples, num samples, seed
- `HiddenStatesConfig`: Layer IDs for hidden state extraction with description
- `GenerationConfig`: Batch size and cache directory
- `FormatConfig`: Output file pattern and detailed schema documentation
- `ExampleData`: Real prompt/output examples from vLLM generation
- `DataGenerationConfig`: Complete configuration combining all components

### Integration
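The component layout above can be sketched with plain dataclasses. This is a minimal illustration, not the PR's actual code: the field names and defaults shown here are assumptions inferred from the component descriptions.

```python
# Hypothetical sketch of the dataclass-based config components.
# Field names and defaults are illustrative assumptions, not the
# PR's exact definitions in config_generator.py.
from dataclasses import asdict, dataclass, field


@dataclass
class PackageVersions:
    torch: str = ""
    vllm: str = ""
    transformers: str = ""
    speculators: str = ""


@dataclass
class ModelConfig:
    target_model_path: str = ""
    tensor_parallel_size: int = 1
    max_model_len: int = 2048
    gpu_memory_utilization: float = 0.8
    hidden_size: int = 4096


@dataclass
class DataGenerationConfig:
    # Top-level config combining the components; only two are shown here.
    version: str = "2.0"
    model: ModelConfig = field(default_factory=ModelConfig)
    packages: PackageVersions = field(default_factory=PackageVersions)


config = DataGenerationConfig(
    model=ModelConfig(target_model_path="meta-llama/Llama-3.1-8B")
)
# asdict() turns the nested dataclasses into a JSON-serializable dict
print(asdict(config)["model"]["target_model_path"])
```

Nesting dataclasses this way gives type-checked fields while still serializing to the plain dict that `json.dump` expects.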
- Integrates with `VllmHiddenStatesGenerator` via `extract_config_from_generator()`
- Updates `scripts/data_generation_offline.py` without breaking changes

### Code Quality
### Testing

## Files Changed

### New Files
- `src/speculators/data_generation/config_generator.py` (345 lines)
  - `generate_config()`: Direct config creation from parameters
  - `extract_config_from_generator()`: Extract config from `VllmHiddenStatesGenerator`
  - `generate_example_data()`: Create example data from token IDs
- `tests/unit/data_generation/test_config_generator.py` (395 lines, 18 tests)
- `tests/unit/data_generation/__init__.py` (new)

### Modified Files
- `scripts/data_generation_offline.py`: updated `save_config()` to use `extract_config_from_generator()`, simplifying the `save_config()` function
- `src/speculators/data_generation/__init__.py`: exports `DataGenerationConfig`, `generate_config`, and `extract_config_from_generator`

## Example Usage
### Direct Config Generation
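A sketch of the direct flow follows. The real `generate_config()` lives in `src/speculators/data_generation/config_generator.py`; the parameter names and the local `_Config` stand-in below are assumptions for illustration, not the PR's actual signature.

```python
# Hypothetical sketch of generate_config(): build a config directly
# from parameters. _Config is a local stand-in for the PR's
# DataGenerationConfig; parameter names are assumed.
import json
from dataclasses import asdict, dataclass


@dataclass
class _Config:
    target_model_path: str
    train_data_path: str
    seq_length: int
    batch_size: int


def generate_config(
    target_model_path: str,
    train_data_path: str,
    seq_length: int = 2048,
    batch_size: int = 8,
) -> _Config:
    # The real implementation also records package versions, GPU info,
    # the launch command, and a cache key for reproducibility.
    return _Config(target_model_path, train_data_path, seq_length, batch_size)


cfg = generate_config("meta-llama/Llama-3.1-8B", "sharegpt")
print(json.dumps(asdict(cfg), indent=2))
```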
### Extract from Generator
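The generator-based flow reads settings off an existing instance rather than taking them as parameters. The attribute names on the fake generator below are assumptions; `_FakeGenerator` and `_Config` are local stand-ins for `VllmHiddenStatesGenerator` and `DataGenerationConfig`.

```python
# Hypothetical sketch of extract_config_from_generator(): pull the
# settings already stored on a generator instance, so the saved config
# matches what actually ran. Attribute names are assumptions.
from dataclasses import dataclass


class _FakeGenerator:  # stand-in for VllmHiddenStatesGenerator
    target_model_path = "meta-llama/Llama-3.1-8B"
    layer_ids = [2, 14, 24, 31]
    batch_size = 8


@dataclass
class _Config:  # stand-in for DataGenerationConfig
    target_model_path: str
    layer_ids: list
    batch_size: int


def extract_config_from_generator(gen) -> _Config:
    # Copy generation settings directly off the generator object
    return _Config(gen.target_model_path, list(gen.layer_ids), gen.batch_size)


cfg = extract_config_from_generator(_FakeGenerator())
print(cfg.layer_ids)  # → [2, 14, 24, 31]
```

Extracting from the live generator avoids the drift that plagued the old manual dictionary, where the saved config and the actual run could silently disagree.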
## Benefits

- Reproducibility: Full tracking of environment, packages, and generation parameters
- Documentation: Self-documenting schema with actual examples
## Testing

### Unit Tests

All tests pass:
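The exact command is not captured in this draft; a typical invocation for the test file added by this PR would be something like the following (run from the repository root — the fallback `echo` keeps the snippet harmless outside a checkout):

```shell
# Assumed invocation; the PR's actual command is not shown in this draft.
pytest tests/unit/data_generation/test_config_generator.py -q \
  || echo "run from the speculators repository root"
```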
## Example Config v2.0 Output
```json
{
  "version": "2.0",
  "generated_at": "2025-10-31T13:15:00.123456",
  "speculators_version": "0.2.0.dev67",
  "reproducibility": {
    "command": "data_generation_offline.py --target-model-path meta-llama/Llama-3.1-8B ...",
    "cache_key": "abc123def456...",
    "gpu": "NVIDIA H100 80GB HBM3",
    "packages": {
      "torch": "2.1.0+cu121",
      "vllm": "0.6.3",
      "transformers": "4.40.0",
      "speculators": "0.2.0.dev67"
    }
  },
  "model": {
    "target_model_path": "meta-llama/Llama-3.1-8B",
    "tensor_parallel_size": 1,
    "max_model_len": 2048,
    "gpu_memory_utilization": 0.8,
    "hidden_size": 4096
  },
  "data": {
    "train_data_path": "sharegpt",
    "chat_template": "llama3",
    "seq_length": 2048,
    "max_samples": null,
    "num_samples": 1000,
    "seed": 0
  },
  "hidden_states": {
    "layer_ids": [2, 14, 24, 31],
    "description": "3 layers for EAGLE3 fusion, last layer for target logits"
  },
  "generation": {
    "batch_size": 8,
    "cache_dir": "./cache"
  },
  "format": {
    "file_pattern": "data_{idx}.pt",
    "schema": {
      "input_ids": { "dtype": "torch.long", "shape": "[seq_len]" },
      "hidden_states": {
        "dtype": "list[torch.bfloat16]",
        "shape": "list of [seq_len, 4096]",
        "num_tensors": 4
      },
      "loss_mask": { "dtype": "torch.long", "shape": "[seq_len]" }
    }
  },
  "example_prompt_token_ids": [1, 2, 3, 4, 5],
  "example_prompt_str": "The quick brown fox jumps over the lazy dog.",
  "example_output_token_ids": [10, 11, 12, 13, 14],
  "example_output_str": "This is an example of generated text."
}
```
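A consumer can sanity-check a saved config with plain `json`. The snippet below embeds a trimmed copy of the example output above; nothing here depends on the speculators package.

```python
# Consumer-side sanity check of the v2.0 config format; the JSON below
# is a trimmed copy of the example output above.
import json

config_text = """
{
  "version": "2.0",
  "model": {"target_model_path": "meta-llama/Llama-3.1-8B", "hidden_size": 4096},
  "hidden_states": {
    "layer_ids": [2, 14, 24, 31],
    "description": "3 layers for EAGLE3 fusion, last layer for target logits"
  }
}
"""

config = json.loads(config_text)
# 3 fusion layers plus the last layer for target logits
assert len(config["hidden_states"]["layer_ids"]) == 4
print(config["version"])  # → 2.0
```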