93 changes: 46 additions & 47 deletions scripts/data_generation_offline.py
100644 → 100755
@@ -20,12 +20,14 @@
import json
import logging
import os
from datetime import datetime
from argparse import Namespace

import torch
from datasets import Dataset as HFDataset
from datasets import load_from_disk
from tqdm import tqdm # type: ignore[import-untyped]

from speculators.data_generation.config_generator import extract_config_from_generator
from speculators.data_generation.logging_utils import PipelineLogger
from speculators.data_generation.preprocessing import (
    generate_cache_key,
@@ -41,8 +43,10 @@
log = PipelineLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Generate EAGLE training data offline")
def parse_args() -> Namespace:
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Generate EAGLE training data offline"
    )

    # Model arguments
    parser.add_argument(
@@ -148,8 +152,13 @@ def parse_args():
    return parser.parse_args()


def load_or_preprocess_dataset(args):
    """Load preprocessed dataset from cache, or run preprocessing if needed."""
def load_or_preprocess_dataset(args: Namespace) -> HFDataset:
    """Load preprocessed dataset from cache, or run preprocessing if needed.

    This automatically handles preprocessing if the cached data doesn't exist,
    making the pipeline more user-friendly and preventing parameter mismatches.
    """
    # Generate cache key (must match the one used during preprocessing)
    cache_key = generate_cache_key(
        args.target_model_path,
        args.chat_template,
@@ -205,52 +214,42 @@ def find_last_checkpoint(output_dir: str) -> int:
    return max_index + 1


def save_config(args, generator, num_samples, output_dir):
    """Save metadata config file for reproducibility"""
    config = {
        "version": "1.0",
        "generated_at": datetime.now().isoformat(),
        "model": {
            "target_model_path": args.target_model_path,
            "tensor_parallel_size": args.tensor_parallel_size,
            "max_model_len": args.max_model_len,
            "gpu_memory_utilization": args.gpu_memory_utilization,
        },
        "data": {
            "train_data_path": args.train_data_path,
            "chat_template": args.chat_template,
            "seq_length": args.seq_length,
            "max_samples": args.max_samples,
            "num_samples": num_samples,
            "seed": args.seed,
        },
        "hidden_states": {
            "layer_ids": generator.layer_ids,
            "num_layers": len(generator.layer_ids),
            "description": (
                "First 3 layers for EAGLE3 fusion, last layer for target logits"
            ),
        },
        "generation": {
            "batch_size": args.batch_size,
            "cache_dir": args.cache_dir,
        },
        "format": {
            "file_pattern": "data_{idx}.pt",
            "fields": ["input_ids", "hidden_states", "loss_mask"],
            "hidden_states_type": "List[torch.Tensor]",
            "hidden_states_shape": "List of [seq_len, hidden_dim], one per layer",
            "note": ("hidden_states is a list of tensors in order of layer_ids"),
        },
    }
def save_config(
    args: Namespace,
    generator: VllmHiddenStatesGenerator,
    num_samples: int,
    output_dir: str,
) -> None:
    """
    Save metadata config file for reproducibility.

    Uses the new config generator (v2.0) with enhanced metadata including
    package versions, GPU info, example prompts, and detailed schema.
    """
    log.subsection("Saving configuration metadata")

    # Generate config using new config generator
    config = extract_config_from_generator(
        generator=generator,
        train_data_path=args.train_data_path,
        chat_template=args.chat_template,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        cache_dir=args.cache_dir,
        num_samples=num_samples,
        max_samples=args.max_samples,
        seed=args.seed,
    )

    # Save to JSON
    config_path = os.path.join(output_dir, "data_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    log.info(f"Saved config to {config_path}")
        json.dump(config.to_dict(), f, indent=2)

    log.success(f"Saved config v{config.version} to {config_path}")


def generate_and_save_hidden_states(args, dataset):
def generate_and_save_hidden_states(args: Namespace, dataset: HFDataset) -> int:
"""Generate hidden states and save each sample as a .pt file"""
os.makedirs(args.output_dir, exist_ok=True)

@@ -313,7 +312,7 @@ def generate_and_save_hidden_states(args, dataset):
    return samples_saved


def main():
def main() -> None:
    args = parse_args()

    log.section("EAGLE Offline Data Generation")
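Note: for readers new to this pipeline, the sketch below shows how the artifacts written by this script could be inspected after a run. It is a rough, untested illustration: the output directory path is hypothetical, the exact keys inside `data_config.json` now depend on what `extract_config_from_generator()` emits (so the config is only dumped, not parsed), and it assumes `input_ids` and `loss_mask` are tensors. The `data_{idx}.pt` fields and the per-layer hidden-state shapes are taken from the format metadata shown in this diff.

```python
import json
import os

import torch

output_dir = "generated_data"  # hypothetical output directory

# Inspect the metadata written by save_config(); the key layout comes from
# extract_config_from_generator(), so dump it rather than assume a schema.
with open(os.path.join(output_dir, "data_config.json")) as f:
    config = json.load(f)
print(json.dumps(config, indent=2))

# Load one generated sample. Per the format metadata above, each data_{idx}.pt
# holds input_ids, loss_mask, and hidden_states: a list of [seq_len, hidden_dim]
# tensors, one per entry in layer_ids.
sample = torch.load(os.path.join(output_dir, "data_0.pt"))
print(sample["input_ids"].shape, sample["loss_mask"].shape)
for layer_hidden in sample["hidden_states"]:
    print(tuple(layer_hidden.shape))
```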
16 changes: 14 additions & 2 deletions src/speculators/data_generation/__init__.py
@@ -1,5 +1,17 @@
"""Data generation utilities for EAGLE-style speculative decoding training."""

from .vllm_hidden_states_generator import VllmHiddenStatesGenerator
from speculators.data_generation.config_generator import (
    DataGenerationConfig,
    extract_config_from_generator,
    generate_config,
)
from speculators.data_generation.vllm_hidden_states_generator import (
    VllmHiddenStatesGenerator,
)

__all__ = ["VllmHiddenStatesGenerator"]
__all__ = [
"DataGenerationConfig",
"VllmHiddenStatesGenerator",
"extract_config_from_generator",
"generate_config",
]
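Note: the re-exports above widen the package's public surface. Below is a minimal, untested sketch of how a caller might use it after this change; the `VllmHiddenStatesGenerator` constructor arguments are deliberately elided (they are not part of this diff), and the literal values passed to `extract_config_from_generator()` are placeholders, with the keyword names taken from `save_config()` in the script above.

```python
from speculators.data_generation import (
    VllmHiddenStatesGenerator,
    extract_config_from_generator,
)

# Construction elided: the generator's __init__ signature is not shown in this PR.
generator = VllmHiddenStatesGenerator(...)

# Keyword names mirror the save_config() call in data_generation_offline.py;
# the values here are placeholders only.
config = extract_config_from_generator(
    generator=generator,
    train_data_path="train.jsonl",
    chat_template="default",
    seq_length=2048,
    batch_size=8,
    cache_dir="./cache",
    num_samples=1000,
    max_samples=None,
    seed=0,
)

print(config.version)            # also logged by save_config()
print(sorted(config.to_dict()))  # serialized to data_config.json by the script
```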