diff --git a/docs/recipes/post_training/reason1/physical-plausibility-check/assets/custom_dataset_sft_config.toml b/docs/recipes/post_training/reason1/physical-plausibility-check/assets/custom_dataset_sft_config.toml new file mode 100644 index 0000000..63f2c9d --- /dev/null +++ b/docs/recipes/post_training/reason1/physical-plausibility-check/assets/custom_dataset_sft_config.toml @@ -0,0 +1,64 @@ +# Training Configuration for Custom Dataset (Transfer1) +# Dataset: Custom labeled videos with physical plausibility scores +# Model: Cosmos-Reason1-7B +# +# This configuration is optimized for 8 GPUs. Adjust dp_shard_size based on your setup: +# - 2 GPUs: dp_shard_size = 2 +# - 4 GPUs: dp_shard_size = 4 +# - 8 GPUs: dp_shard_size = 8 + +[custom.dataset] +# Path to training dataset with conversation format +path = "data/transfer1_split_with_conv/train" + +[train] +# Training epochs and output configuration +epoch = 10 +output_dir = "outputs/transfer1_sft" +compile = false +train_batch_per_replica = 32 + +# Evaluation configuration +eval_steps = 50 +evaluation_strategy = "steps" +save_strategy = "steps" +load_best_model_at_end = true +metric_for_best_model = "eval_loss" + +[policy] +# Model configuration +model_name_or_path = "nvidia/Cosmos-Reason1-7B" +model_max_length = 4096 + +[logging] +# Logging configuration +logger = ['console', 'tensorboard'] +project_name = "cosmos_reason1" +experiment_name = "post_training_hf/transfer1_sft" + +[train.train_policy] +# Training policy configuration +type = "sft" +conversation_column_name = "conversations" +mini_batch = 4 + +[train.eval_policy] +# Evaluation dataset configuration +dataset.name = "data/transfer1_split_with_conv/eval" + +[train.ckpt] +# Checkpoint configuration +enable_checkpoint = true +save_freq = 50 +max_keep = 5 +save_mode = "async" + +[policy.parallelism] +# Parallelism configuration +tp_size = 1 +cp_size = 1 +dp_shard_size = 8 +pp_size = 1 +dp_replicate_size = 1 + + diff --git a/docs/recipes/post_training/reason1/physical-plausibility-check/post_training.md b/docs/recipes/post_training/reason1/physical-plausibility-check/post_training.md index c694ba3..21cb1a2 100644 --- a/docs/recipes/post_training/reason1/physical-plausibility-check/post_training.md +++ b/docs/recipes/post_training/reason1/physical-plausibility-check/post_training.md @@ -372,12 +372,208 @@ Realistically, dough should stretch and fold in certain ways when rolled or shap - **Model prediction**: 2. (The prediction matches the ground truth.) - **Summary of the model output**: The analysis has successfully identified the key issues in the video, including unnatural deformation, inconsistent texture, gravity-defying movement, abrupt motion changes, and unrealistic food preparation behavior. +## Fine-Tuning on Custom Datasets + +Having demonstrated fine-tuning on the public VideoPhy-2 dataset, we now show how to adapt this methodology to custom datasets. This section uses videos generated by Cosmos Transfer 2.5 with human-labeled physical plausibility scores. + +### Dataset Preparation + +The custom dataset workflow supports local video files with human-annotated quality scores. The dataset preparation involves: + +1. **Data Organization**: Videos and associated metadata (prompts, labels) +2. **Train/Eval Split**: Stratified splitting to maintain label distribution +3. **Conversation Format**: Converting to the format required for SFT training + +### Step 1: Create Train/Eval Split + +The first step creates a stratified train/eval split from local videos with labels. 
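+Conceptually, a stratified split draws the eval set so that each label keeps roughly the same proportion it has in the full dataset. For reference only, the same idea expressed with scikit-learn (the provided script implements the split manually and does not depend on scikit-learn; the file names and labels below are hypothetical):
+
+```python
+from sklearn.model_selection import train_test_split
+
+# Hypothetical file names and physical-plausibility labels, purely for illustration.
+videos = [f"video_{i:03d}.mp4" for i in range(20)]
+labels = [1] * 10 + [5] * 10
+
+train_videos, eval_videos, train_labels, eval_labels = train_test_split(
+    videos, labels,
+    test_size=0.1,      # same role as --eval_size 0.1
+    stratify=labels,    # keep label proportions equal in both splits
+    random_state=42,    # same role as --random_seed 42
+)
+```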
Copy the script from the cookbook: + +```bash +# In cosmos-reason1 root directory +cp /path/to/cosmos-cookbook/scripts/examples/reason1/physical-plausibility-check/create_dataset_with_split.py \ + examples/post_training_hf/scripts/ +``` + +Prepare your data directory structure: + +``` +data/ +├── transfer1_generated_videos/ # Video files (.mp4) +├── prompts/ # Prompt text files (.txt) +└── transfer25_human_labeled.xlsx # Labels spreadsheet +``` + +**Example prompt file** (`prompts/video_001_prompt.txt`): + +``` +A person waves hello to another person approaching from the left +``` + +**Example labels spreadsheet** (`transfer25_human_labeled.xlsx`): + +| output_link | Action alignment | Physical common sense | Quality | +|-------------|------------------|----------------------|---------| +| https://example.com/videos/video_001.mp4 | 5 | 5 | 5 | +| https://example.com/videos/video_002.mp4 | 4 | 3 | 4 | +| https://example.com/videos/video_003.mp4 | 2 | 1 | 2 | + +The script expects: +- **output_link**: Video URL or path (used to match video files) +- **Physical common sense**: Score 0-1 or 1-5 (use `--scale_labels` to convert 0-1 to 1-5) + +**Note**: If your video URLs don't match the filename pattern, customize the `extract_filename_from_url()` function in `create_dataset_with_split.py`. The script includes examples for simple and complex URL patterns. + +Run the script to create train/eval split: + +```bash +cd examples/post_training_hf/ + +uv run scripts/create_dataset_with_split.py \ + --output_dir data/transfer1_split \ + --data_dir data \ + --excel_file transfer25_human_labeled.xlsx \ + --eval_size 0.1 \ + --balance_labels \ + --scale_labels +``` + +**Key Options:** + +- `--eval_size 0.1`: 10% of data for evaluation +- `--balance_labels`: Balance label distribution across classes +- `--scale_labels`: Map binary labels (0,1) to 1-5 scale +- `--random_seed 42`: Reproducible splitting + +### Step 2: Add Conversation Format + +The second step converts the dataset to conversation format required for training. Copy the script: + +```bash +cp /path/to/cosmos-cookbook/scripts/examples/reason1/physical-plausibility-check/add_conversations_to_dataset.py \ + examples/post_training_hf/scripts/ +``` + +Convert both train and eval splits: + +```bash +# Process train split +uv run scripts/add_conversations_to_dataset.py \ + --input_dir data/transfer1_split/train \ + --output_dir data/transfer1_split_with_conv/train \ + --prompt_path prompts/video_reward.yaml + +# Process eval split +uv run scripts/add_conversations_to_dataset.py \ + --input_dir data/transfer1_split/eval \ + --output_dir data/transfer1_split_with_conv/eval \ + --prompt_path prompts/video_reward.yaml +``` + +### Step 3: Configure Training + +Copy the training configuration from the cookbook: + +```bash +cp /path/to/cosmos-cookbook/docs/recipes/post_training/reason1/physical-plausibility-check/assets/custom_dataset_sft_config.toml \ + examples/post_training_hf/configs/transfer1_sft.toml +``` + +The training uses the existing `scripts/custom_sft.py` script already available in the cosmos-reason1 repository. 
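+Before launching training, it can help to spot-check the converted dataset from Step 2. A minimal sketch (paths assume the directory layout used above):
+
+```python
+import json
+
+import datasets
+
+# Load the train split produced by Step 2 and inspect one converted sample.
+train_ds = datasets.load_from_disk("data/transfer1_split_with_conv/train")
+print(list(train_ds.features))  # expect: caption, video_url, pc, conversations
+print(json.loads(train_ds[0]["conversations"]))
+```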
+ +**Key Configuration Parameters** (from `configs/transfer1_sft.toml`): + +- `custom.dataset.path`: Path to training dataset (`"data/transfer1_split_with_conv/train"`) +- `train.epoch`: Number of training epochs (10) +- `train.eval_steps`: Evaluate every 50 steps +- `train.output_dir`: Output directory for checkpoints (`"outputs/transfer1_sft"`) +- `policy.model_name_or_path`: Base model (`"nvidia/Cosmos-Reason1-7B"`) +- `policy.parallelism.dp_shard_size`: Data parallel sharding - adjust based on GPUs (2, 4, or 8) +- `train.ckpt.save_freq`: Save checkpoint every 50 steps +- `train.ckpt.max_keep`: Keep 5 best checkpoints + +### Step 4: Run Training + +Start the fine-tuning process: + +```bash +cd examples/post_training_hf/ +cosmos-rl --config configs/transfer1_sft.toml scripts/custom_sft.py +``` + +Training outputs are saved to `outputs/transfer1_sft/[timestamp]/`: + +- `safetensors/step_*/`: Model checkpoints +- `tensorboard/`: Training metrics + +Monitor training progress with TensorBoard: + +```bash +tensorboard --logdir outputs/transfer1_sft/ +``` + +### Step 5: Evaluate Fine-Tuned Model + +After training, evaluate the model on the evaluation dataset. Copy the evaluation script: + +```bash +cp /path/to/cosmos-cookbook/scripts/examples/reason1/physical-plausibility-check/evaluate_model.py \ + examples/post_training_hf/scripts/ +``` + +Run evaluation: + +```bash +uv run scripts/evaluate_model.py \ + --model_path outputs/transfer1_sft/[timestamp]/safetensors/step_80 \ + --eval_dataset data/transfer1_split_with_conv/eval \ + --prompt_path prompts/video_reward.yaml \ + --output_dir eval_results +``` + +The evaluation generates: + +- `evaluation_results.json`: Detailed metrics +- `evaluation_report.html`: Interactive HTML report + +**Evaluation Metrics:** + +- **Exact Accuracy**: Percentage of exact score matches +- **Within ±1 Accuracy**: Predictions within 1 point of ground truth +- **Mean Absolute Error**: Average prediction error +- **Binary Classification**: Precision, recall, F1 for good vs bad videos + +### Results and Analysis + +The fine-tuned model shows improved performance on custom datasets. The evaluation report provides: + +- Overall accuracy metrics +- Confusion matrix showing prediction patterns +- Per-sample results with model responses +- Binary classification metrics for quality filtering + +This workflow can be adapted to other video quality assessment tasks by: + +1. Organizing videos and labels in the specified format +2. Adjusting the prompt template for your specific task +3. Modifying the label scaling if using different score ranges + ## Conclusion -Fine-tuning Cosmos Reason 1 on VideoPhy-2 data significantly improves physical plausibility prediction, progressing from zero-shot (0.293 correlation) to SFT (0.395) and RL (0.425). Key insights: +This case study demonstrates the full spectrum of fine-tuning Cosmos Reason 1 for physical plausibility prediction: + +- **Zero-shot Performance**: The base model shows strong understanding of physical laws without fine-tuning +- **Supervised Fine-Tuning**: Training on VideoPhy-2 improves correlation from 0.293 to 0.395 +- **Reinforcement Learning**: Further enhancement to 0.425 correlation with better reasoning traces +- **Custom Dataset Adaptation**: Complete workflow for fine-tuning on domain-specific datasets + +Key insights: - **Progressive improvement**: Each training stage (SFT, RL) delivers measurable gains in both accuracy and correlation, with RL achieving the best overall performance. 
 - **Thinking traces enhance interpretability**: RL training with structured prompts enables the model to generate detailed reasoning traces that explain its predictions.
-- **Flexibility**: This methodology can be adapted to other video quality assessment tasks by substituting the datasets and defining appropriate metrics.
+- **Flexibility**: The methodology can be adapted to custom datasets and other video quality assessment tasks by following the dataset preparation workflow and adjusting prompts and metrics.
+
+The custom dataset workflow enables practitioners to:
-As a next step, we can investigate reasoning SFT as a warmup step using datasets that contain thinking traces. This can improve the model's reasoning ability before RL training.
+1. Leverage videos from Cosmos Transfer or other sources
+2. Apply human labeling for domain-specific quality criteria
+3. Fine-tune models for specialized use cases in video generation quality control
diff --git a/scripts/examples/reason1/physical-plausibility-check/add_conversations_to_dataset.py b/scripts/examples/reason1/physical-plausibility-check/add_conversations_to_dataset.py
new file mode 100644
index 0000000..3ba1e74
--- /dev/null
+++ b/scripts/examples/reason1/physical-plausibility-check/add_conversations_to_dataset.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+"""Add conversation format to existing datasets.
+
+This script converts datasets with caption/video_url/pc format
+to the conversation format required for training.
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import datasets
+import yaml
+from cosmos_reason1_utils.text import PromptConfig, create_conversation
+from tqdm import tqdm
+
+
+def main():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--input_dir", type=str, required=True,
+        help="Input dataset directory"
+    )
+    parser.add_argument(
+        "--output_dir", type=str, required=True,
+        help="Output dataset directory"
+    )
+    parser.add_argument(
+        "--prompt_path", type=str, required=True,
+        help="Path to prompt YAML file"
+    )
+    args = parser.parse_args()
+
+    # Load prompt template
+    print(f"📝 Loading prompt from: {args.prompt_path}")
+    with open(args.prompt_path, 'r') as f:
+        prompt_config = PromptConfig.model_validate(yaml.safe_load(f))
+
+    system_prompt = prompt_config.system_prompt
+    user_prompt = prompt_config.user_prompt
+
+    # Load existing dataset
+    print(f"📂 Loading dataset from: {args.input_dir}")
+    dataset = datasets.load_from_disk(args.input_dir)
+    print(f"✅ Loaded {len(dataset)} samples")
+    print(f"Current features: {list(dataset.features.keys())}")
+
+    # Convert to conversation format
+    print("\n🔄 Converting to conversation format...")
+    conversations = []
+
+    for sample in tqdm(dataset, desc="Processing samples"):
+        video_path = sample['video_url']
+        pc_score = sample['pc']
+
+        # Create conversation; the assistant response is the ground-truth score
+        # wrapped in <answer> tags, matching what evaluate_model.py parses.
+        conversation = create_conversation(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+            videos=[video_path],
+            response=f"<answer>\n{pc_score}\n</answer>",
+        )
+
+        conversations.append(json.dumps(conversation))
+
+    # Add conversations column to dataset
+    dataset = dataset.add_column("conversations", conversations)
+
+    print(f"\n✅ Added 'conversations' column")
+    print(f"New features: {list(dataset.features.keys())}")
+
+    # Save updated dataset
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"\n💾 Saving to: {output_dir}")
+    dataset.save_to_disk(str(output_dir))
+
+    print(f"\n✅ Dataset saved successfully!")
+    print(f"\nSample
conversation:") + print(json.loads(dataset[0]['conversations'])) + + +if __name__ == "__main__": + main() + + diff --git a/scripts/examples/reason1/physical-plausibility-check/create_dataset_with_split.py b/scripts/examples/reason1/physical-plausibility-check/create_dataset_with_split.py new file mode 100644 index 0000000..f9fdb79 --- /dev/null +++ b/scripts/examples/reason1/physical-plausibility-check/create_dataset_with_split.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create training and evaluation datasets from local video files. + +This script creates train/eval splits from local video files with human-labeled quality scores. +Supports stratified splitting to maintain label distribution in both sets. +""" + +import argparse +import json +import os +import random +import re +import sys +from collections import Counter +from pathlib import Path +from typing import Optional + +try: + import datasets + import pandas as pd + import yaml + from rich import print + from tqdm import tqdm +except ImportError as e: + print(f"Error: Missing required package: {e}") + print("Please install dependencies:") + print(" pip install datasets pandas openpyxl pyyaml rich tqdm") + sys.exit(1) + + +def extract_filename_from_url(url: str) -> Optional[str]: + """Extract local filename from URL. + + This function extracts the filename from video URLs. Customize this + function based on your URL structure. 
Examples: + + - Simple: Extract just the filename + url = "https://example.com/videos/video_001.mp4" -> "video_001.mp4" + + - Complex: Parse structured paths and reconstruct filenames + url = "https://example.com/action/wave/segment_01/video_001.mp4" + -> "wave_segment_01_video_001.mp4" + """ + # Example 1: Simple filename extraction + # Uncomment and modify for your use case: + # return url.split('/')[-1] # Returns last part of URL + + # Example 2: Custom pattern matching + # Modify this pattern to match your URL structure + pattern = r'([\w]+)/(com_\d+_\d+_[a-f0-9]+)_segment_(\d+)_left/gpu_(\d+)/video_(\d+)/output\.mp4' + match = re.search(pattern, url) + if match: + action = match.group(1) + timestamp_id = match.group(2) + segment = match.group(3) + gpu = match.group(4) + video = match.group(5) + filename = f'{action}_{timestamp_id}_segment_{segment}_left_gpu_{gpu}_video_{video}_output.mp4' + return filename + + # Fallback: Try extracting just the filename + if '/' in url: + return url.split('/')[-1] + + return None + + +def balance_dataset_labels(dataset: datasets.Dataset, verbose: bool = True) -> datasets.Dataset: + """Balance dataset by resampling so each label appears the same number of times.""" + random.seed(42) + + # Extract PC labels and group samples by label + label_to_indices = {} + for i, sample in enumerate(dataset): + pc_score = sample.get("pc") + if pc_score is not None: + if pc_score not in label_to_indices: + label_to_indices[pc_score] = [] + label_to_indices[pc_score].append(i) + + if verbose: + print("\n📊 Original label distribution:") + for label in sorted(label_to_indices.keys()): + count = len(label_to_indices[label]) + print(f" Label {label}: {count} samples") + + # target samples per label is the average number of samples per label + target_samples_per_label = len(dataset) // len(label_to_indices) + + if verbose: + print(f"\n🎯 Target samples per label: {target_samples_per_label}") + + # Resample each label to target count + balanced_indices = [] + for label, indices in label_to_indices.items(): + if len(indices) >= target_samples_per_label: + if verbose: + print(f"Downsampling label {label} from {len(indices)} to {target_samples_per_label}") + selected_indices = random.sample(indices, target_samples_per_label) + else: + if verbose: + print(f"Upsampling label {label} from {len(indices)} to {target_samples_per_label}") + selected_indices = random.choices(indices, k=target_samples_per_label) + + balanced_indices.extend(selected_indices) + + # Shuffle the balanced indices + random.shuffle(balanced_indices) + + # Create new balanced dataset + balanced_data = [dataset[i] for i in balanced_indices] + balanced_dataset = datasets.Dataset.from_list(balanced_data) + + if verbose: + print("\n📊 Final balanced label distribution:") + final_label_counts = Counter(sample["pc"] for sample in balanced_dataset) + for label in sorted(final_label_counts.keys()): + print(f" Label {label}: {final_label_counts[label]} samples") + print(f"\n✅ Dataset balanced: {len(dataset)} → {len(balanced_dataset)} samples") + + return balanced_dataset + + +def stratified_split(dataset: datasets.Dataset, eval_size: float = 0.1, random_seed: int = 42): + """Split dataset into train/eval while maintaining label distribution.""" + random.seed(random_seed) + + # Group indices by label + label_to_indices = {} + for i, sample in enumerate(dataset): + pc_score = sample.get("pc") + if pc_score is not None: + if pc_score not in label_to_indices: + label_to_indices[pc_score] = [] + 
label_to_indices[pc_score].append(i) + + train_indices = [] + eval_indices = [] + + # Split each label proportionally + for label, indices in label_to_indices.items(): + random.shuffle(indices) + split_point = int(len(indices) * (1 - eval_size)) + train_indices.extend(indices[:split_point]) + eval_indices.extend(indices[split_point:]) + + # Shuffle to mix labels + random.shuffle(train_indices) + random.shuffle(eval_indices) + + # Create datasets + train_data = [dataset[i] for i in train_indices] + eval_data = [dataset[i] for i in eval_indices] + + train_dataset = datasets.Dataset.from_list(train_data) + eval_dataset = datasets.Dataset.from_list(eval_data) + + return train_dataset, eval_dataset + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--output_dir", type=str, help="Output directory for train/eval datasets.", required=True + ) + parser.add_argument( + "--data_dir", + type=str, + default="data", + help="Path to data directory containing videos, prompts, and Excel file.", + ) + parser.add_argument( + "--excel_file", + type=str, + default="transfer25_human_labeled.xlsx", + help="Excel file with video URLs and labels.", + ) + parser.add_argument( + "--eval_size", + type=float, + default=0.1, + help="Fraction of data to use for evaluation (default: 0.1 = 10%%).", + ) + parser.add_argument( + "--balance_labels", + action="store_true", + help="Balance dataset labels before splitting.", + ) + parser.add_argument( + "--scale_labels", + action="store_true", + help="Map binary labels (0,1) to 1-5 scale: 0→1 (bad), 1→5 (good).", + ) + parser.add_argument( + "--random_seed", + type=int, + default=42, + help="Random seed for reproducibility.", + ) + args = parser.parse_args() + + # Set random seed + random.seed(args.random_seed) + + output_dir = Path(args.output_dir).resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + # Resolve paths + data_dir = Path(args.data_dir).resolve() + video_dir = data_dir / "transfer1_generated_videos" + prompt_dir = data_dir / "prompts" + excel_path = data_dir / args.excel_file + + print(f"📂 Data directory: {data_dir}") + print(f"📹 Video directory: {video_dir}") + print(f"📝 Prompt directory: {prompt_dir}") + print(f"📊 Excel file: {excel_path}") + print(f"📁 Output directory: {output_dir}") + print(f"🎲 Random seed: {args.random_seed}") + print(f"📊 Eval size: {args.eval_size * 100:.0f}%") + + # Read Excel file + print("\n📖 Reading Excel file...") + df = pd.read_excel(excel_path, skiprows=1, names=["video_url", "label"]) + print(f"Found {len(df)} labeled videos") + print(f"Label distribution: {df['label'].value_counts().to_dict()}") + + # Extract filenames from URLs + df["filename"] = df["video_url"].apply(extract_filename_from_url) + missing_filenames = df["filename"].isna().sum() + if missing_filenames > 0: + print(f"⚠️ Warning: {missing_filenames} URLs couldn't be parsed") + df = df[df["filename"].notna()].reset_index(drop=True) + + # Verify files exist + df["video_path"] = df["filename"].apply(lambda x: str(video_dir / x)) + df["prompt_path"] = df["filename"].apply( + lambda x: str(prompt_dir / x.replace("_output.mp4", "_prompt.txt")) + ) + + df["video_exists"] = df["video_path"].apply(os.path.exists) + df["prompt_exists"] = df["prompt_path"].apply(os.path.exists) + + missing_videos = (~df["video_exists"]).sum() + missing_prompts = (~df["prompt_exists"]).sum() + + if missing_videos > 0: + print(f"⚠️ Warning: {missing_videos} videos not found 
locally") + if missing_prompts > 0: + print(f"⚠️ Warning: {missing_prompts} prompts not found locally") + + df = df[df["video_exists"] & df["prompt_exists"]].reset_index(drop=True) + print(f"✅ {len(df)} samples have both video and prompt files") + + # Scale labels if requested + if args.scale_labels: + print("\n🔄 Scaling labels: 0→1 (bad physics), 1→5 (good physics)") + df["pc"] = df["label"].apply(lambda x: 1 if x == 0 else 5) + else: + df["pc"] = df["label"] + + # Read prompts + print("\n📝 Reading prompt files...") + prompts = [] + for prompt_path in tqdm(df["prompt_path"], desc="Loading prompts"): + try: + with open(prompt_path, "r") as f: + prompt_text = f.read().strip() + prompts.append(prompt_text) + except Exception as e: + print(f"⚠️ Error reading {prompt_path}: {e}") + prompts.append("") + + df["caption"] = prompts + + # Create dataset + dataset_dict = { + "caption": df["caption"].tolist(), + "video_url": df["video_path"].tolist(), + "pc": df["pc"].tolist(), + } + + full_dataset = datasets.Dataset.from_dict(dataset_dict) + print(f"\n📦 Created full dataset with {len(full_dataset)} samples") + + # Balance if requested (before splitting) + if args.balance_labels: + print("\n⚖️ Balancing dataset labels...") + full_dataset = balance_dataset_labels(full_dataset) + + # Perform stratified split + print(f"\n✂️ Splitting dataset: {(1-args.eval_size)*100:.0f}% train, {args.eval_size*100:.0f}% eval") + train_dataset, eval_dataset = stratified_split(full_dataset, eval_size=args.eval_size, random_seed=args.random_seed) + + # Print split statistics + print(f"\n📊 Split Statistics:") + print(f" Train: {len(train_dataset)} samples") + train_label_counts = Counter(train_dataset["pc"]) + for label in sorted(train_label_counts.keys()): + print(f" Label {label}: {train_label_counts[label]} samples ({train_label_counts[label]/len(train_dataset)*100:.1f}%)") + + print(f"\n Eval: {len(eval_dataset)} samples") + eval_label_counts = Counter(eval_dataset["pc"]) + for label in sorted(eval_label_counts.keys()): + print(f" Label {label}: {eval_label_counts[label]} samples ({eval_label_counts[label]/len(eval_dataset)*100:.1f}%)") + + # Save datasets + train_path = output_dir / "train" + eval_path = output_dir / "eval" + + print(f"\n💾 Saving datasets...") + train_dataset.save_to_disk(str(train_path)) + eval_dataset.save_to_disk(str(eval_path)) + + print(f"✅ Train dataset saved to: {train_path}") + print(f"✅ Eval dataset saved to: {eval_path}") + + # Save split info + split_info = { + "total_samples": len(full_dataset), + "train_samples": len(train_dataset), + "eval_samples": len(eval_dataset), + "eval_fraction": args.eval_size, + "random_seed": args.random_seed, + "balanced": args.balance_labels, + "scaled_labels": args.scale_labels, + "train_label_distribution": dict(train_label_counts), + "eval_label_distribution": dict(eval_label_counts), + } + + split_info_path = output_dir / "split_info.json" + with open(split_info_path, "w") as f: + json.dump(split_info, f, indent=2) + + print(f"📄 Split info saved to: {split_info_path}") + + print("\n" + "="*80) + print("✅ DATASET CREATION COMPLETE!") + print("="*80) + print(f"\nDataset locations:") + print(f" Train: {train_path}") + print(f" Eval: {eval_path}") + print(f"\nTo load in Python:") + print(f" import datasets") + print(f" train_ds = datasets.load_from_disk('{train_path}')") + print(f" eval_ds = datasets.load_from_disk('{eval_path}')") + print("="*80) + + +if __name__ == "__main__": + main() + + diff --git 
a/scripts/examples/reason1/physical-plausibility-check/evaluate_model.py b/scripts/examples/reason1/physical-plausibility-check/evaluate_model.py
new file mode 100644
index 0000000..06562a6
--- /dev/null
+++ b/scripts/examples/reason1/physical-plausibility-check/evaluate_model.py
@@ -0,0 +1,514 @@
+#!/usr/bin/env python3
+"""
+Evaluate Fine-tuned Cosmos Reason 1 Model on Transfer1 Evaluation Dataset
+
+Usage:
+    python3 scripts/evaluate_model.py \
+        --model_path outputs/transfer1_sft/20251023145904/checkpoints/step_80/policy \
+        --eval_dataset data/transfer1_split_with_conv/eval \
+        --prompt_path prompts/video_reward.yaml \
+        --output_dir eval_results
+"""
+
+import argparse
+import json
+import re
+import time
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from collections import Counter, defaultdict
+
+import yaml
+from datasets import load_from_disk
+from qwen_vl_utils import process_vision_info
+from transformers import AutoProcessor
+from vllm import LLM, SamplingParams
+
+
+def parse_response(response):
+    """Parse response to extract the integer score from <answer> tags."""
+    try:
+        # Try XML parsing first
+        wrapped = f"<root>{response.strip()}</root>"
+        root = ET.fromstring(wrapped)
+        answer_element = root.find("answer")
+
+        if answer_element is not None and answer_element.text:
+            answer_text = answer_element.text.strip()
+            try:
+                answer_int = int(answer_text)
+                # Ensure score is in valid range
+                if 1 <= answer_int <= 5:
+                    return answer_int
+            except ValueError:
+                pass
+
+        # Try regex as fallback
+        match = re.search(r"<answer>\s*(\d+)\s*</answer>", response)
+        if match:
+            try:
+                answer_int = int(match.group(1))
+                if 1 <= answer_int <= 5:
+                    return answer_int
+            except ValueError:
+                pass
+
+    except Exception:
+        pass
+
+    return None
+
+
+def load_prompt_config(prompt_path):
+    """Load prompt configuration from YAML file."""
+    with open(prompt_path, 'r') as f:
+        config = yaml.safe_load(f)
+    return config.get('system_prompt', ''), config.get('user_prompt', '')
+
+
+def run_inference_batch(llm, processor, video_paths, system_prompt, user_prompt, batch_size=4):
+    """Run inference on a batch of videos."""
+    sampling_params = SamplingParams(
+        temperature=0.1,  # Low temperature for more deterministic outputs
+        top_k=10,
+        top_p=0.9,
+        repetition_penalty=1.05,
+        max_tokens=512,  # Shorter for evaluation
+    )
+
+    results = []
+
+    for i in range(0, len(video_paths), batch_size):
+        batch_videos = video_paths[i:i+batch_size]
+        batch_inputs = []
+
+        for video_path in batch_videos:
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "video",
+                            "video": video_path,
+                            "fps": 16,
+                            "total_pixels": 8192 * 28 * 28,
+                        },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                },
+            ]
+
+            prompt = processor.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+
+            image_inputs, video_inputs, video_kwargs = process_vision_info(
+                messages, return_video_kwargs=True
+            )
+
+            mm_data = {}
+            if image_inputs is not None:
+                mm_data["image"] = image_inputs
+            if video_inputs is not None:
+                mm_data["video"] = video_inputs
+
+            batch_inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": mm_data,
+                "mm_processor_kwargs": video_kwargs,
+            })
+
+        # Generate responses for batch
+        outputs = llm.generate(batch_inputs, sampling_params=sampling_params)
+
+        for output in outputs:
+            response_text = output.outputs[0].text
+            predicted_score = parse_response(response_text)
+            results.append({
+                'response': response_text,
+                'predicted_score': predicted_score
+            })
+
+    return
results + + +def calculate_metrics(predictions, ground_truths): + """Calculate evaluation metrics.""" + # Filter out failed predictions + valid_pairs = [(pred, gt) for pred, gt in zip(predictions, ground_truths) if pred is not None] + + if not valid_pairs: + return None + + predictions_valid = [p for p, _ in valid_pairs] + ground_truths_valid = [g for _, g in valid_pairs] + + # Exact accuracy + exact_matches = sum(1 for pred, gt in valid_pairs if pred == gt) + exact_accuracy = exact_matches / len(valid_pairs) + + # Accuracy within 1 point + within_1 = sum(1 for pred, gt in valid_pairs if abs(pred - gt) <= 1) + within_1_accuracy = within_1 / len(valid_pairs) + + # Mean Absolute Error + mae = sum(abs(pred - gt) for pred, gt in valid_pairs) / len(valid_pairs) + + # Confusion matrix + confusion_matrix = defaultdict(lambda: defaultdict(int)) + for pred, gt in valid_pairs: + confusion_matrix[gt][pred] += 1 + + # Binary classification metrics (1 vs 5) + # Ground truth: 1 = bad, 5 = good + binary_predictions = [1 if p <= 2 else (5 if p >= 4 else 3) for p in predictions_valid] + binary_ground_truth = [g for g in ground_truths_valid] + + # True positives, false positives, etc. for score 1 (bad videos) + tp_bad = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred == 1 and gt == 1) + fp_bad = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred == 1 and gt != 1) + tn_bad = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred != 1 and gt != 1) + fn_bad = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred != 1 and gt == 1) + + # True positives, false positives, etc. for score 5 (good videos) + tp_good = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred == 5 and gt == 5) + fp_good = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred == 5 and gt != 5) + tn_good = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred != 5 and gt != 5) + fn_good = sum(1 for pred, gt in zip(binary_predictions, binary_ground_truth) if pred != 5 and gt == 5) + + # Precision, Recall, F1 for bad videos + precision_bad = tp_bad / (tp_bad + fp_bad) if (tp_bad + fp_bad) > 0 else 0 + recall_bad = tp_bad / (tp_bad + fn_bad) if (tp_bad + fn_bad) > 0 else 0 + f1_bad = 2 * precision_bad * recall_bad / (precision_bad + recall_bad) if (precision_bad + recall_bad) > 0 else 0 + + # Precision, Recall, F1 for good videos + precision_good = tp_good / (tp_good + fp_good) if (tp_good + fp_good) > 0 else 0 + recall_good = tp_good / (tp_good + fn_good) if (tp_good + fn_good) > 0 else 0 + f1_good = 2 * precision_good * recall_good / (precision_good + recall_good) if (precision_good + recall_good) > 0 else 0 + + return { + 'total_samples': len(predictions), + 'valid_predictions': len(valid_pairs), + 'failed_predictions': len(predictions) - len(valid_pairs), + 'exact_accuracy': exact_accuracy, + 'within_1_accuracy': within_1_accuracy, + 'mean_absolute_error': mae, + 'confusion_matrix': dict(confusion_matrix), + 'binary_metrics': { + 'bad_videos': { + 'precision': precision_bad, + 'recall': recall_bad, + 'f1_score': f1_bad, + 'true_positives': tp_bad, + 'false_positives': fp_bad, + 'true_negatives': tn_bad, + 'false_negatives': fn_bad, + }, + 'good_videos': { + 'precision': precision_good, + 'recall': recall_good, + 'f1_score': f1_good, + 'true_positives': tp_good, + 'false_positives': fp_good, + 'true_negatives': tn_good, + 'false_negatives': fn_good, + } + }, + 'score_distribution': { + 
+            'predictions': dict(Counter(predictions_valid)),
+            'ground_truth': dict(Counter(ground_truths_valid)),
+        }
+    }
+
+
+def generate_html_report(results, metrics, output_path):
+    """Generate HTML evaluation report."""
+    html = """<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <title>Fine-tuned Model Evaluation Report</title>
+    <style>
+        body { font-family: Arial, sans-serif; margin: 40px; background: #f5f5f5; }
+        h1, h2 { color: #333; }
+        .subtitle { color: #666; margin-bottom: 30px; }
+        .metrics-grid { display: flex; flex-wrap: wrap; gap: 16px; }
+        .metric-card { background: #fff; border-radius: 8px; padding: 16px; min-width: 180px;
+                       box-shadow: 0 1px 3px rgba(0,0,0,0.1); }
+        .metric-label { color: #666; font-size: 14px; }
+        .metric-value { font-size: 28px; font-weight: bold; color: #2c7be5; }
+        .metric-detail { color: #888; font-size: 12px; margin-top: 4px; }
+        table { border-collapse: collapse; background: #fff; margin-top: 10px; }
+        th, td { border: 1px solid #ddd; padding: 8px 12px; text-align: center; }
+        th { background: #f0f0f0; }
+        .result { background: #fff; border-radius: 8px; padding: 12px; margin: 10px 0;
+                  border-left: 6px solid #ccc; }
+        .result.correct { border-left-color: #28a745; }
+        .result.incorrect { border-left-color: #dc3545; }
+        .result.failed { border-left-color: #ffc107; }
+        .response { white-space: pre-wrap; background: #f8f9fa; padding: 8px; font-size: 13px; }
+    </style>
+</head>
+<body>
+    <h1>Fine-tuned Model Evaluation Report</h1>
+    <div class="subtitle">Cosmos Reason 1 - Transfer1 Dataset Evaluation</div>
+"""
+
+    # Metrics section
+    html += """
+    <h2>Overall Metrics</h2>
+    <div class="metrics-grid">
+"""
+
+    if metrics:
+        html += f"""
+        <div class="metric-card">
+            <div class="metric-label">Exact Accuracy</div>
+            <div class="metric-value">{metrics['exact_accuracy']:.1%}</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-label">Within ±1 Accuracy</div>
+            <div class="metric-value">{metrics['within_1_accuracy']:.1%}</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-label">Mean Absolute Error</div>
+            <div class="metric-value">{metrics['mean_absolute_error']:.2f}</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-label">Valid Predictions</div>
+            <div class="metric-value">{metrics['valid_predictions']}/{metrics['total_samples']}</div>
+        </div>
+    </div>
+
+    <h2>Binary Classification Metrics</h2>
+    <div class="metrics-grid">
+        <div class="metric-card">
+            <div class="metric-label">Bad Videos F1-Score</div>
+            <div class="metric-value">{metrics['binary_metrics']['bad_videos']['f1_score']:.1%}</div>
+            <div class="metric-detail">
+                Precision: {metrics['binary_metrics']['bad_videos']['precision']:.1%} |
+                Recall: {metrics['binary_metrics']['bad_videos']['recall']:.1%}
+            </div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-label">Good Videos F1-Score</div>
+            <div class="metric-value">{metrics['binary_metrics']['good_videos']['f1_score']:.1%}</div>
+            <div class="metric-detail">
+                Precision: {metrics['binary_metrics']['good_videos']['precision']:.1%} |
+                Recall: {metrics['binary_metrics']['good_videos']['recall']:.1%}
+            </div>
+        </div>
+    </div>
+"""
+
+    # Confusion matrix
+    if metrics and metrics.get('confusion_matrix'):
+        html += """
+    <h2>Confusion Matrix</h2>
+    <table>
+        <tr>
+            <th>Ground Truth \\ Predicted</th>
+"""
+        # Get all scores
+        all_scores = sorted(set(list(metrics['confusion_matrix'].keys()) +
+                                [pred for preds in metrics['confusion_matrix'].values() for pred in preds.keys()]))
+
+        for score in all_scores:
+            html += f"<th>Score {score}</th>"
+        html += "</tr>"
+
+        for gt_score in all_scores:
+            html += f"<tr><th>Score {gt_score}</th>"
+            for pred_score in all_scores:
+                count = metrics['confusion_matrix'].get(gt_score, {}).get(pred_score, 0)
+                html += f"<td>{count}</td>"
+            html += "</tr>"
+
+        html += """
+    </table>
+"""
+
+    # Detailed results
+    html += """
+    <h2>Detailed Results</h2>
+"""
+
+    for i, result in enumerate(results[:50], 1):  # Show first 50 results
+        video_name = Path(result['video_path']).name
+        pred_score = result['predicted_score']
+        gt_score = result['ground_truth']
+
+        if pred_score is None:
+            css_class = "failed"
+            status = "Failed to Parse"
+        elif pred_score == gt_score:
+            css_class = "correct"
+            status = "Correct"
+        else:
+            css_class = "incorrect"
+            status = "Incorrect"
+
+        html += f"""
+    <div class="result {css_class}">
+        <strong>{status} - Sample {i}</strong>
+        <div>{video_name}</div>
+        <div>
+            Ground Truth: {gt_score}
+            Predicted: {pred_score if pred_score else 'N/A'}
+        </div>
+        <details>
+            <summary>Show Response</summary>
+            <div class="response">{result['response']}</div>
+        </details>
+    </div>
+"""
+
+    if len(results) > 50:
+        html += f'<p>... and {len(results) - 50} more results</p>'
+
+    html += """
+</body>
+</html>
+ + +""" + + with open(output_path, 'w') as f: + f.write(html) + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate fine-tuned model on evaluation dataset") + parser.add_argument("--model_path", type=str, required=True, help="Path to fine-tuned model checkpoint") + parser.add_argument("--eval_dataset", type=str, required=True, help="Path to evaluation dataset") + parser.add_argument("--prompt_path", type=str, default="prompts/video_reward.yaml", help="Path to prompt config") + parser.add_argument("--output_dir", type=str, default="eval_results", help="Output directory for results") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size for inference") + args = parser.parse_args() + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + print("=" * 80) + print("Fine-tuned Model Evaluation") + print("=" * 80) + print(f"Model: {args.model_path}") + print(f"Dataset: {args.eval_dataset}") + print(f"Output: {args.output_dir}") + print("=" * 80) + print() + + # Load prompt configuration + print("Loading prompt configuration...") + system_prompt, user_prompt = load_prompt_config(args.prompt_path) + + # Load evaluation dataset + print("Loading evaluation dataset...") + eval_dataset = load_from_disk(args.eval_dataset) + print(f" Loaded {len(eval_dataset)} samples") + print() + + # Load fine-tuned model + print("Loading fine-tuned model...") + print(" (This may take a few minutes...)") + llm = LLM( + model=args.model_path, + limit_mm_per_prompt={"image": 0, "video": 1}, + enforce_eager=True, + ) + processor = AutoProcessor.from_pretrained(args.model_path) + print(" Model loaded successfully") + print() + + # Run inference + print("Running inference...") + video_paths = [sample['video_url'] for sample in eval_dataset] + ground_truths = [sample['pc'] for sample in eval_dataset] + + start_time = time.time() + inference_results = run_inference_batch( + llm, processor, video_paths, system_prompt, user_prompt, + batch_size=args.batch_size + ) + elapsed_time = time.time() - start_time + + print(f" Inference completed in {elapsed_time:.1f} seconds") + print(f" Average: {elapsed_time/len(eval_dataset):.2f} seconds/sample") + print() + + # Combine results + results = [] + predictions = [] + for i, (sample, inference_result) in enumerate(zip(eval_dataset, inference_results)): + result = { + 'video_path': sample['video_url'], + 'ground_truth': sample['pc'], + 'predicted_score': inference_result['predicted_score'], + 'response': inference_result['response'], + 'caption': sample['caption'] + } + results.append(result) + predictions.append(inference_result['predicted_score']) + + # Calculate metrics + print("Calculating metrics...") + metrics = calculate_metrics(predictions, ground_truths) + + if metrics: + print() + print("=" * 80) + print("EVALUATION RESULTS") + print("=" * 80) + print(f"Total Samples: {metrics['total_samples']}") + print(f"Valid Predictions: {metrics['valid_predictions']}") + print(f"Failed Predictions: {metrics['failed_predictions']}") + print(f"Exact Accuracy: {metrics['exact_accuracy']:.2%}") + print(f"Within ±1 Accuracy: {metrics['within_1_accuracy']:.2%}") + print(f"Mean Absolute Error: {metrics['mean_absolute_error']:.3f}") + print() + print("Binary Classification (Bad vs Good):") + print(f" Bad Videos F1: {metrics['binary_metrics']['bad_videos']['f1_score']:.2%}") + print(f" Good Videos F1: {metrics['binary_metrics']['good_videos']['f1_score']:.2%}") + print("=" * 80) + + # Save results 
+ print() + print("Saving results...") + + # Save JSON + json_path = output_dir / "evaluation_results.json" + with open(json_path, 'w') as f: + json.dump({ + 'metrics': metrics, + 'results': results, + 'config': { + 'model_path': args.model_path, + 'eval_dataset': args.eval_dataset, + 'num_samples': len(eval_dataset), + } + }, f, indent=2) + print(f" JSON: {json_path}") + + # Generate HTML report + html_path = output_dir / "evaluation_report.html" + generate_html_report(results, metrics, html_path) + print(f" HTML: {html_path}") + + print() + print("=" * 80) + print("Evaluation completed successfully!") + print(f"Results saved to: {args.output_dir}") + print(f"Open the HTML report: {html_path}") + print("=" * 80) + + +if __name__ == "__main__": + main() + +