"""Benchmark generation speeds across different attention backends using subprocesses."""

import json
import os
import subprocess
import sys
import tempfile

import matplotlib.pyplot as plt
import numpy as np

# Attention backends to benchmark, in run order. Each name is passed to the
# worker script as its command-line argument, and the worker exports it as
# DIFFUSERS_ATTN_BACKEND before importing diffusers. The "native" entry serves
# as the baseline for the speedup table printed by main().
BACKENDS = [
    "native",
    "flash_varlen",
    "xformers",
    "flash_hub",
    "flash_varlen_hub",
]

# Source of the worker script that benchmarks a single backend.
#
# Protocol with the parent process:
#   * argv[1] is the backend name; it is exported as DIFFUSERS_ATTN_BACKEND
#     before diffusers is imported.
#   * All progress messages go to stderr.
#   * Exactly one JSON object is printed to stdout as the final line, with
#     the keys "backend", "single_time" and "batch_time" (seconds).
#
# Timings use time.perf_counter(), a monotonic high-resolution clock, so that
# system clock adjustments cannot distort the measured durations.
WORKER_SCRIPT = """
import os
import sys
import time
import torch
import json

backend_name = sys.argv[1]

# Set backend before importing
os.environ["DIFFUSERS_ATTN_BACKEND"] = backend_name

from diffusers import QwenImagePipeline

print(f"Benchmarking: {backend_name}", file=sys.stderr)

# Load pipeline
print("Loading pipeline...", file=sys.stderr)
load_start = time.perf_counter()
pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=torch.bfloat16
).to("cuda")
load_time = time.perf_counter() - load_start
print(f"  Load time: {load_time:.2f}s", file=sys.stderr)

# Warmup run
print("Warming up...", file=sys.stderr)
_ = pipe(
    prompt="warmup",
    height=512,
    width=512,
    num_inference_steps=5
)
torch.cuda.synchronize()

# Benchmark single generation
print("Benchmarking single generation...", file=sys.stderr)
torch.cuda.synchronize()
start = time.perf_counter()
_ = pipe(
    prompt="A serene mountain landscape at sunset",
    height=512,
    width=512,
    num_inference_steps=30
)
torch.cuda.synchronize()
single_time = time.perf_counter() - start
print(f"  Single generation: {single_time:.2f}s", file=sys.stderr)

# Benchmark batch generation
print("Benchmarking batch generation...", file=sys.stderr)
torch.cuda.synchronize()
start = time.perf_counter()
_ = pipe(
    prompt=["cat", "A beautiful detailed painting of mountains and lakes"],
    height=512,
    width=512,
    num_inference_steps=25
)
torch.cuda.synchronize()
batch_time = time.perf_counter() - start
print(f"  Batch generation: {batch_time:.2f}s", file=sys.stderr)

# Output results as JSON
result = {
    "backend": backend_name,
    "single_time": single_time,
    "batch_time": batch_time,
}
print(json.dumps(result))
"""


def benchmark_backend(backend_name):
    """Run the benchmark for a single attention backend in a subprocess.

    The worker script is written into a private temporary directory (removed
    again afterwards) and executed with the current interpreter, passing
    ``backend_name`` as its only argument. The worker's stderr, which carries
    its progress messages, is echoed once. The result is decoded from the
    last non-empty line of the worker's stdout, so stray output printed
    before the JSON line does not break parsing.

    Args:
        backend_name: Backend identifier forwarded to the worker script.

    Returns:
        The object decoded from the worker's JSON result line, or ``None``
        if the worker could not be launched, timed out after 10 minutes,
        exited with a non-zero status, or produced unparsable output.
    """
    print(f"\n{'=' * 80}")
    print(f"Starting subprocess for: {backend_name}")
    print(f"{'=' * 80}")

    try:
        # A fresh temporary directory replaces the fixed /tmp path: it works
        # on every platform, cannot collide with a concurrent run, and is
        # cleaned up automatically when the block exits.
        with tempfile.TemporaryDirectory(prefix="benchmark_worker_") as tmp_dir:
            worker_file = os.path.join(tmp_dir, "benchmark_worker.py")
            with open(worker_file, "w", encoding="utf-8") as f:
                f.write(WORKER_SCRIPT)

            result = subprocess.run(
                [sys.executable, worker_file, backend_name],
                capture_output=True,
                text=True,
                timeout=600,  # 10 minute timeout
            )
    except subprocess.TimeoutExpired:
        print(f"✗ {backend_name} timed out after 10 minutes")
        return None
    except Exception as e:
        # Best-effort: one backend failing to launch must not abort the
        # remaining benchmarks, so report it and move on.
        print(f"✗ {backend_name} failed: {e}")
        import traceback

        traceback.print_exc()
        return None

    # Echo the worker's progress messages (and any traceback) exactly once.
    if result.stderr:
        print(result.stderr, end="")

    if result.returncode != 0:
        print(f"✗ {backend_name} failed with return code {result.returncode}")
        if result.stdout:
            print("STDOUT:", result.stdout)
        return None

    # The worker prints its JSON result last; anything printed to stdout
    # before it (e.g. by imported libraries) is ignored.
    stdout_lines = [line for line in result.stdout.splitlines() if line.strip()]
    payload = stdout_lines[-1] if stdout_lines else ""
    try:
        result_data = json.loads(payload)
    except json.JSONDecodeError as e:
        print(f"✗ {backend_name} failed to parse results: {e}")
        print("STDOUT:", result.stdout)
        return None

    print(f"✓ {backend_name} completed successfully")
    return result_data


def main():
    """Benchmark every backend in BACKENDS, then plot and print the timings.

    Backends whose benchmark returned no result are skipped. If nothing
    succeeded, a message is printed and no plot or table is produced.
    Otherwise a two-panel bar chart is saved to
    ``backend_benchmark_complete.png`` in the current directory, followed by
    a summary table and, when a "native" result is available, a table of
    speedups relative to it.
    """
    results = []

    for backend in BACKENDS:
        result = benchmark_backend(backend)
        if result:
            results.append(result)

    # Create visualization
    print(f"\n{'=' * 80}")
    print("Creating benchmark visualization...")
    print(f"{'=' * 80}\n")

    if not results:
        print("No results to plot!")
        return

    # Extract data
    backends = [r["backend"] for r in results]
    single_times = [r["single_time"] for r in results]
    batch_times = [r["batch_time"] for r in results]

    # Create figure with two subplots sharing one colour per backend
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    colors = plt.cm.Set3(np.linspace(0, 1, len(backends)))

    _plot_timing_bars(
        ax1,
        backends,
        single_times,
        colors,
        "Single Image Generation (30 steps, 512x512)",
    )
    _plot_timing_bars(
        ax2,
        backends,
        batch_times,
        colors,
        "Batch Generation (2 images, 25 steps, 512x512)",
    )

    # Overall title
    fig.suptitle(
        "QwenImage Attention Backend Performance Comparison",
        fontsize=15,
        fontweight="bold",
        y=0.98,
    )

    plt.tight_layout()

    # Save figure, then release it so it does not stay registered in pyplot
    output_path = "backend_benchmark_complete.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"✓ Benchmark plot saved to: {output_path}")

    # Print summary table
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS SUMMARY")
    print("=" * 80)
    print(f"{'Backend':<20} {'Single (30 steps)':<20} {'Batch (25 steps)':<20}")
    print("-" * 80)
    for r in results:
        print(
            f"{r['backend']:<20} {r['single_time']:>8.2f}s {'':<11} {r['batch_time']:>8.2f}s"
        )
    print("=" * 80)

    # Calculate speedups relative to native. The baseline is looked up by
    # name so the table does not depend on the order of BACKENDS.
    baseline = next((r for r in results if r["backend"] == "native"), None)
    if baseline is not None:
        baseline_single = baseline["single_time"]
        baseline_batch = baseline["batch_time"]
        print("\nSpeedup vs Native Backend:")
        print("-" * 80)
        for r in results:
            single_speedup = baseline_single / r["single_time"]
            batch_speedup = baseline_batch / r["batch_time"]
            print(
                f"{r['backend']:<20} Single: {single_speedup:>5.2f}x    Batch: {batch_speedup:>5.2f}x"
            )
        print("=" * 80)


def _plot_timing_bars(ax, backends, times, colors, title):
    """Draw one bar chart of per-backend timings, with value labels, on ``ax``."""
    bars = ax.bar(backends, times, color=colors, edgecolor="black", linewidth=1.5)
    ax.set_ylabel("Time (seconds)", fontsize=12, fontweight="bold")
    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.set_xlabel("Attention Backend", fontsize=12, fontweight="bold")
    ax.tick_params(axis="x", rotation=45)
    ax.grid(axis="y", alpha=0.3, linestyle="--")

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.2f}s",
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10,
        )


# Run the benchmarks only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
