Conversation

@rahul-tuli
Collaborator

This PR adds a comprehensive type-safe configuration generator for the EAGLE data generation pipeline, replacing the manual dictionary-based configuration with a robust dataclass architecture.

Key Changes

Type-Safe Configuration Architecture

  • Dataclass-based design using Python's @dataclass decorator for type safety
  • Reproducibility tracking including package versions, GPU info, and cache keys
  • Real example generation from vLLM generator output (not hardcoded)
  • JSON-serializable output with full schema documentation

Configuration Components

The new system includes well-defined dataclasses for each configuration section:

  • PackageVersions: Tracks torch, vllm, transformers, and speculators versions
  • ReproducibilityInfo: Command, cache key, GPU info, and package versions
  • ModelConfig: Target model path, tensor parallel size, max model len, GPU memory utilization, hidden size
  • DataConfig: Training data path, chat template, sequence length, max samples, num samples, seed
  • HiddenStatesConfig: Layer IDs for hidden state extraction with description
  • GenerationConfig: Batch size and cache directory
  • FormatConfig: Output file pattern and detailed schema documentation
  • ExampleData: Real prompt/output examples from vLLM generation
  • DataGenerationConfig: Complete configuration combining all components
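
The components above map naturally onto nested dataclasses. As an illustrative sketch (field names follow the list above and the example config below, but the actual definitions in `config_generator.py` may differ):

```python
from dataclasses import dataclass, asdict

@dataclass
class PackageVersions:
    # Versions of the key packages captured at generation time
    torch: str
    vllm: str
    transformers: str
    speculators: str

@dataclass
class ReproducibilityInfo:
    # Everything needed to re-run the exact same data generation
    command: str
    cache_key: str
    gpu: str
    packages: PackageVersions

info = ReproducibilityInfo(
    command="data_generation_offline.py ...",
    cache_key="abc123",
    gpu="NVIDIA H100 80GB HBM3",
    packages=PackageVersions("2.1.0", "0.6.3", "4.40.0", "0.2.0"),
)
# asdict() recurses into nested dataclasses, producing a JSON-ready dict
print(asdict(info)["packages"]["torch"])  # → 2.1.0
```

Nesting dataclasses this way keeps each config section independently testable while `asdict()` handles serialization of the whole tree.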

Integration

  • Integrates seamlessly with VllmHiddenStatesGenerator via extract_config_from_generator()
  • Works with existing scripts/data_generation_offline.py without breaking changes
  • Generates config v2.0 with enhanced metadata (upgraded from v1.0)
  • Fully backward compatible - all v1.0 fields are preserved

Testing

  • ✅ Test coverage includes:
    • Package version detection from environment
    • GPU info detection (CUDA/CPU)
    • Hidden size extraction from model configs (including text_config fallback)
    • Example data generation from real token IDs
    • Full config generation flow
    • Generator extraction workflow
    • JSON serialization and deserialization
    • Dataclass creation and validation

Files Changed

New Files

  • src/speculators/data_generation/config_generator.py (345 lines)

    • Main config generator with all dataclass definitions
    • generate_config(): Direct config creation from parameters
    • extract_config_from_generator(): Extract config from VllmHiddenStatesGenerator
    • generate_example_data(): Create example data from token IDs
    • Helper functions for GPU info and hidden size detection
  • tests/unit/data_generation/test_config_generator.py (395 lines, 18 tests)

    • Comprehensive test suite with 18 unit tests
    • Tests all dataclass creation paths
    • Tests config generation with real and mocked components
    • Tests JSON serialization
    • Tests generator extraction workflow
  • tests/unit/data_generation/__init__.py (new)

    • Test module initialization

Modified Files

  • scripts/data_generation_offline.py

    • Updated save_config() to use extract_config_from_generator()
    • Removed manual dictionary-based config construction (v1.0)
    • Now generates config v2.0 with enhanced metadata
    • Reduced the save_config() function from 48 lines to 33 lines
  • src/speculators/data_generation/__init__.py

    • Export config generator API: DataGenerationConfig, generate_config, extract_config_from_generator
    • Makes config generator easily accessible from package
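
Putting the modified-file changes together, the simplified save_config() might look roughly like this (a hedged sketch: the argument names and surrounding script structure are assumptions, though extract_config_from_generator is exported by the package per the changes above):

```python
import json
from pathlib import Path

def save_config(generator, output_dir: str, **data_args) -> None:
    # Build the full v2.0 config from the live generator instead of
    # assembling a dictionary by hand (the old v1.0 approach).
    from speculators.data_generation import extract_config_from_generator

    config = extract_config_from_generator(generator=generator, **data_args)
    path = Path(output_dir) / "data_config.json"
    with path.open("w") as f:
        json.dump(config.to_dict(), f, indent=2)
```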

Example Usage

Direct Config Generation

from speculators.data_generation.config_generator import generate_config

config = generate_config(
    target_model_path="meta-llama/Llama-3.1-8B",
    train_data_path="sharegpt",
    chat_template="llama3",
    seq_length=2048,
    layer_ids=[2, 14, 24, 31],
    tensor_parallel_size=1,
    max_model_len=2048,
    gpu_memory_utilization=0.8,
    batch_size=8,
    cache_dir="./cache",
    num_samples=1000,
    example_prompt_token_ids=[1, 2, 3, 4, 5],
    example_output_token_ids=[10, 11, 12],
)

# Save to JSON
import json
with open("data_config.json", "w") as f:
    json.dump(config.to_dict(), f, indent=2)

Extract from Generator

from speculators.data_generation.config_generator import extract_config_from_generator

# Extract config from vLLM generator (includes real examples)
config = extract_config_from_generator(
    generator=vllm_generator,
    train_data_path="sharegpt",
    chat_template="llama3",
    seq_length=2048,
    batch_size=8,
    cache_dir="./cache",
    num_samples=1000,
)

# Config includes real example prompt/output from vLLM generation
print(f"Example prompt: {config.example_prompt_str}")
print(f"Example output: {config.example_output_str}")

Benefits

  1. Reproducibility: Full tracking of environment, packages, and generation parameters

    • Package versions (torch, vllm, transformers, speculators)
    • GPU info (device name, count)
    • Command line used to generate data
    • Cache key for deterministic preprocessing
  2. Documentation: Self-documenting schema with actual examples

    • Real prompt/output token IDs from vLLM generation
    • Detailed schema documentation for output format
    • Clear field descriptions
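
A deterministic cache key like the one tracked above can be derived by hashing the canonicalized preprocessing parameters (an illustrative sketch; the repository's actual key derivation is not shown in this PR):

```python
import hashlib
import json

def compute_cache_key(params: dict) -> str:
    # sort_keys=True makes the serialization canonical, so the same
    # parameters always hash to the same key regardless of dict order
    canonical = json.dumps(params, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()

key = compute_cache_key({
    "train_data_path": "sharegpt",
    "chat_template": "llama3",
    "seq_length": 2048,
    "seed": 0,
})
print(key[:12])
```

Because the key depends only on the parameter values, preprocessed data can be safely reused across runs that share the same key.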

Testing

Unit Tests

All tests pass:

$ python -m pytest tests/unit/data_generation/test_config_generator.py -v
============================= test session starts ==============================
collected 18 items

tests/unit/data_generation/test_config_generator.py::test_package_versions_dataclass PASSED
tests/unit/data_generation/test_config_generator.py::test_package_versions_from_environment PASSED
tests/unit/data_generation/test_config_generator.py::test_reproducibility_info_creation PASSED
tests/unit/data_generation/test_config_generator.py::test_format_config_default PASSED
tests/unit/data_generation/test_config_generator.py::test_get_gpu_info PASSED
tests/unit/data_generation/test_config_generator.py::test_package_versions_speculators_version PASSED
tests/unit/data_generation/test_config_generator.py::test_get_hidden_size_from_model PASSED
tests/unit/data_generation/test_config_generator.py::test_get_hidden_size_from_model_text_config PASSED
tests/unit/data_generation/test_config_generator.py::test_get_hidden_size_from_model_error PASSED
tests/unit/data_generation/test_config_generator.py::test_generate_example_data PASSED
tests/unit/data_generation/test_config_generator.py::test_generate_config_basic PASSED
tests/unit/data_generation/test_config_generator.py::test_generate_config_with_tokenizer PASSED
tests/unit/data_generation/test_config_generator.py::test_config_json_serializable PASSED
tests/unit/data_generation/test_config_generator.py::test_extract_config_from_generator PASSED
tests/unit/data_generation/test_config_generator.py::test_extract_config_uses_generator_tokenizer PASSED
tests/unit/data_generation/test_config_generator.py::test_generate_config_with_max_samples_none PASSED
tests/unit/data_generation/test_config_generator.py::test_reproducibility_info_contains_all_fields PASSED
tests/unit/data_generation/test_config_generator.py::test_format_config_schema_structure PASSED

======================== 18 passed, 1 warning in 5.40s ========================

Example Config v2.0 Output

{
  "version": "2.0",
  "generated_at": "2025-10-31T13:15:00.123456",
  "speculators_version": "0.2.0.dev67",
  "reproducibility": {
    "command": "data_generation_offline.py --target-model-path meta-llama/Llama-3.1-8B ...",
    "cache_key": "abc123def456...",
    "gpu": "NVIDIA H100 80GB HBM3",
    "packages": {
      "torch": "2.1.0+cu121",
      "vllm": "0.6.3",
      "transformers": "4.40.0",
      "speculators": "0.2.0.dev67"
    }
  },
  "model": {
    "target_model_path": "meta-llama/Llama-3.1-8B",
    "tensor_parallel_size": 1,
    "max_model_len": 2048,
    "gpu_memory_utilization": 0.8,
    "hidden_size": 4096
  },
  "data": {
    "train_data_path": "sharegpt",
    "chat_template": "llama3",
    "seq_length": 2048,
    "max_samples": null,
    "num_samples": 1000,
    "seed": 0
  },
  "hidden_states": {
    "layer_ids": [2, 14, 24, 31],
    "description": "3 layers for EAGLE3 fusion, last layer for target logits"
  },
  "generation": {
    "batch_size": 8,
    "cache_dir": "./cache"
  },
  "format": {
    "file_pattern": "data_{idx}.pt",
    "schema": {
      "input_ids": {
        "dtype": "torch.long",
        "shape": "[seq_len]"
      },
      "hidden_states": {
        "dtype": "list[torch.bfloat16]",
        "shape": "list of [seq_len, 4096]",
        "num_tensors": 4
      },
      "loss_mask": {
        "dtype": "torch.long",
        "shape": "[seq_len]"
      }
    }
  },
  "example_prompt_token_ids": [1, 2, 3, 4, 5],
  "example_prompt_str": "The quick brown fox jumps over the lazy dog.",
  "example_output_token_ids": [10, 11, 12, 13, 14],
  "example_output_str": "This is an example of generated text."
}
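
Because the output is plain JSON, downstream training code can load and sanity-check a config without importing the generator package at all. A minimal sketch, using a trimmed-down version of the config above:

```python
import json

config_text = """
{"version": "2.0",
 "model": {"hidden_size": 4096},
 "hidden_states": {"layer_ids": [2, 14, 24, 31]}}
"""
config = json.loads(config_text)

assert config["version"] == "2.0"
# EAGLE3 fusion uses three intermediate layers plus the last layer
assert len(config["hidden_states"]["layer_ids"]) == 4
print(config["model"]["hidden_size"])  # → 4096
```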

@github-actions

github-actions bot commented Oct 31, 2025

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/19074324832/artifacts/4463646258.
They will be retained for up to 30 days.
Commit: 4b4a3e2

@rahul-tuli rahul-tuli requested a review from shanjiaz October 31, 2025 14:29
@rahul-tuli rahul-tuli force-pushed the update-config branch 2 times, most recently from e0f444a to bfc795a on November 4, 2025 15:13
Add comprehensive configuration generator with:
- Type-safe dataclass architecture using Python's @dataclass decorator
- Reproducibility tracking (package versions, GPU info, cache keys)
- Real example generation from vLLM generator output
- RST-formatted docstrings for better documentation
- Automatic schema documentation for output format
- Full test coverage (18 comprehensive unit tests)

Key Features:
- Uses dataclass instances with asdict() for type safety and DRY
- Integrates with VllmHiddenStatesGenerator for real examples
- Supports both direct config creation and generator extraction
- Tracks reproducibility info (command, packages, GPU, cache key)
- JSON-serializable output for config persistence

Code Quality:
- All style checks passing (ruff format, ruff check)
- All quality checks passing
- 18/18 tests passing
- Removed trivial dataclass tests (focus on actual behavior)
- Created reusable pytest fixtures to reduce duplication
- Simplified RST docstrings (param and returns only)

Files Changed:
- src/speculators/data_generation/config_generator.py (345 lines)
- tests/unit/data_generation/test_config_generator.py (395 lines, 18 tests)
- tests/unit/data_generation/__init__.py (new)

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli rahul-tuli force-pushed the update-config branch 2 times, most recently from 7438511 to 7465520 on November 4, 2025 15:27
Update data generation pipeline to use new type-safe config generator:
- Replace manual dictionary-based config with extract_config_from_generator()
- Upgrade from config v1.0 to v2.0 with enhanced metadata
- Add package version tracking and reproducibility info
- Include real example prompts/outputs from vLLM generation
- Export config generator functions from __init__.py

Changes:
- scripts/data_generation_offline.py: Use extract_config_from_generator()
- src/speculators/data_generation/__init__.py: Export config generator API

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>