Changes from all commits (54 commits)
d636214  fix megatron batch validation (Harryllh, Oct 5, 2025)
45706a5  change divisible check (Harryllh, Oct 6, 2025)
c0c1303  check strategy flag existence (Harryllh, Oct 8, 2025)
6cd33bb  Merge branch 'main' of https://github.com/erictang000/SkyRL into mega… (erictang000, Oct 8, 2025)
86c6345  fix tests (erictang000, Oct 8, 2025)
1b3cbf0  Merge branch 'NovaSky-AI:main' into main (Harryllh, Oct 11, 2025)
557a275  Merge branch 'NovaSky-AI:main' into main (Harryllh, Oct 22, 2025)
b0103e7  Merge branch 'NovaSky-AI:main' into main (Harryllh, Oct 29, 2025)
8c3f93e  Merge branch 'NovaSky-AI:main' into main (Harryllh, Nov 12, 2025)
7192106  Merge branch 'NovaSky-AI:main' into main (Harryllh, Dec 4, 2025)
790e2fd  Merge branch 'NovaSky-AI:main' into main (Harryllh, Dec 21, 2025)
fc657ec  launch code (Harryllh, Dec 22, 2025)
0890817  updated bash script (Harryllh, Dec 25, 2025)
2f8000b  dspy integration started (Harryllh, Dec 25, 2025)
6ee5085  simplified code (Harryllh, Dec 26, 2025)
cb51d72  lcb code (Harryllh, Dec 27, 2025)
6b077f8  clean up (Harryllh, Dec 28, 2025)
4b6405c  feat: finish dataset (J-Ch-n, Dec 28, 2025)
40b9e26  refactor: change argument passing (J-Ch-n, Dec 28, 2025)
c76d7f5  feat: add example file (J-Ch-n, Dec 28, 2025)
aa38573  refactor: modify data paths in example file (J-Ch-n, Dec 28, 2025)
f675720  refactor: move example file (J-Ch-n, Dec 28, 2025)
db01413  feat: move is_stdin from test to example level (J-Ch-n, Dec 28, 2025)
9c0d45a  Merge pull request #1 from Harryllh/assertion-dataset (Harryllh, Dec 29, 2025)
765bb48  push before sleep (Harryllh, Dec 29, 2025)
01b3950  trace collection (Harryllh, Dec 29, 2025)
fbcc2c0  bash script (Harryllh, Dec 30, 2025)
409c716  merged from assertion-dataset (Harryllh, Dec 30, 2025)
9973f59  Merge branch 'assertion' of https://github.com/Harryllh/SkyRL into as… (Harryllh, Dec 30, 2025)
201195b  `run_lcb.sh` runs (#2) (J-Ch-n, Dec 30, 2025)
2469ca3  merged from assertion-dataset + clean-up (Harryllh, Dec 30, 2025)
5a12b75  extra dspy programs (Harryllh, Dec 30, 2025)
5a8f1ee  working training (Harryllh, Jan 2, 2026)
ea50aef  epoch and final reward function (Harryllh, Jan 2, 2026)
99c757b  error handling and timeouts (Harryllh, Jan 2, 2026)
ed0b0ed  max tokens (Harryllh, Jan 2, 2026)
44984a4  papillon and hover start code (Harryllh, Jan 3, 2026)
fe3f8c7  new apis and papillon code (Harryllh, Jan 4, 2026)
e3c8681  debug (Harryllh, Jan 4, 2026)
1368923  async papillon supported. async lcb still needs testing (Harryllh, Jan 6, 2026)
4c5a2bb  Hover (#3) (Harryllh, Jan 7, 2026)
d9ad220  trace collection change (Harryllh, Jan 8, 2026)
1e33f96  added chat template (Harryllh, Jan 8, 2026)
f1e7879  robust trace collection. reasoning trace included (Harryllh, Jan 8, 2026)
5d32c7d  ready for 7b training (Harryllh, Jan 8, 2026)
bb5046c  lcb data pipeline (Harryllh, Jan 9, 2026)
5a3ec2e  path issues (Harryllh, Jan 9, 2026)
44ba540  updated 8b training script. vllm engine update. concurrency update. (Harryllh, Jan 11, 2026)
9f11a23  working lcb and LM infra change (Harryllh, Jan 12, 2026)
fc4f7e3  working papillon (Harryllh, Jan 12, 2026)
f056e3e  Banking77 Works (#4) (J-Ch-n, Jan 12, 2026)
a20dc62  efficient lcb (Harryllh, Jan 12, 2026)
fc46729  router script (Harryllh, Jan 12, 2026)
5116e10  push before sleep (Jan 12, 2026)
28 changes: 28 additions & 0 deletions skyrl-train/examples/dspy/banking77/data.py
@@ -0,0 +1,28 @@
import uuid
import dspy
import random
from dspy.datasets import DataLoader
from datasets import load_dataset

CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features["label"].names


def banking77_data():
    kwargs = {"fields": ("text", "label"), "input_keys": ("text",), "split": "train", "trust_remote_code": True}

    trainset = [
        dspy.Example(x, hint=CLASSES[x.label], label=CLASSES[x.label], task_id=uuid.uuid4()).with_inputs("text", "hint")
        for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:1000]
    ]
    validationset = [
        dspy.Example(x, hint=CLASSES[x.label], label=CLASSES[x.label], task_id=uuid.uuid4()).with_inputs("text", "hint")
        for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[1000:1200]
    ]

    random.Random(0).shuffle(trainset)

    return trainset, validationset
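A minimal sketch (not part of the diff) of how banking77_data() might be exercised; the import path is an assumption based on the file location:

# Hypothetical driver, assuming this module is importable as banking77.data.
from banking77.data import banking77_data, CLASSES  # assumed import path

trainset, validationset = banking77_data()
print(len(trainset), len(validationset))  # 1000 train / 200 validation examples

ex = trainset[0]
# hint and label both carry the class name; only text and hint are marked as inputs.
print(ex.text, ex.hint, ex.label)
assert ex.label in CLASSES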
43 changes: 43 additions & 0 deletions skyrl-train/examples/dspy/banking77/programs.py
@@ -0,0 +1,43 @@
import dspy
from .data import CLASSES
from dspy.adapters import XMLAdapter


class Banking77(dspy.Module):
    def __init__(self):
        super().__init__()  # initialize dspy.Module state
        self.intent_classifier = dspy.ChainOfThought(dspy.Signature(f"text -> label: Literal{CLASSES}"))
        self.adapter = XMLAdapter()

    def forward(self, text: str) -> dspy.Prediction:
        intent = self.intent_classifier(text=text)

        return intent


class Banking77_intent_classifier(Banking77):
    def __init__(self):
        super().__init__()
        self.intent_classifier_traces = []
        self.intents = []

    async def forward(self, example) -> dspy.Prediction:
        text = example.get("text")
        intent = await self.intent_classifier.acall(text=text)

        self.append_trace(example, intent)
        return intent

    def append_trace(self, kwargs, pred):
        # Format the (inputs, outputs) pair into chat messages for trace collection.
        finetune_data = self.adapter.format_finetune_data(
            signature=self.intent_classifier.predictors()[0].signature,
            inputs=kwargs,
            outputs=pred,
            demos=[],  # TODO: Add support for demos
        )

        all_messages = finetune_data.get("messages", [])
        self.intent_classifier_traces.extend(all_messages)
        self.intents.append(pred)

    def collect_trace(self, example, pred):
        # example/pred are unused; traces accumulate in append_trace.
        return self.intent_classifier_traces, self.intents
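A rough usage sketch (hypothetical, not from the PR): the subclass's forward is async and expects a dict-like example, so a driver might look like this, assuming a configured dspy.LM and the assumed import path:

# Hypothetical driver: classify one example and inspect the collected traces.
import asyncio
import dspy
from banking77.programs import Banking77_intent_classifier  # assumed import path

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # any dspy-supported model

async def main():
    program = Banking77_intent_classifier()
    pred = await program.forward({"text": "How do I activate my new card?"})
    print(pred.label)

    # collect_trace ignores its arguments; traces were accumulated in append_trace.
    messages, intents = program.collect_trace(None, None)
    print(len(messages), "trace messages,", len(intents), "predictions")

asyncio.run(main())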
22 changes: 22 additions & 0 deletions skyrl-train/examples/dspy/banking77/utils.py
@@ -0,0 +1,22 @@
from .data import CLASSES


# Final task reward (correctness)
async def banking77_final_reward_fn(example, pred, trace=None):
    label = pred.get("label")
    gold = example.get("label")

    if label is None:
        return 0.0

    return 1.0 if label == gold else 0.0


# Local validity / constraint reward
async def banking77_local_reward_fn(example, pred):
    assert len(pred) == 1, "Pred should have only one element"
    label = pred[0].get("label")

    if label is None:
        return 0.0

    return 0.5 if label in CLASSES else 0.0
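Both reward functions only call .get on their inputs, so plain dicts work for a quick check (hypothetical snippet; "activate_my_card" is one of the real banking77 class names):

# Hypothetical check of the two reward functions with dict stand-ins.
import asyncio
from banking77.utils import banking77_final_reward_fn, banking77_local_reward_fn  # assumed import path

example = {"text": "How do I activate my new card?", "label": "activate_my_card"}
pred = {"label": "activate_my_card"}

print(asyncio.run(banking77_final_reward_fn(example, pred)))    # 1.0: label matches gold
print(asyncio.run(banking77_local_reward_fn(example, [pred])))  # 0.5: label is a valid class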
290 changes: 290 additions & 0 deletions skyrl-train/examples/dspy/dataset.py
@@ -0,0 +1,290 @@
import logging
from typing import List, Any, Union
from uuid import UUID, uuid4
import dspy
import json
import pickle
import os
import base64
import zlib
from pydantic import BaseModel
from .utils import get_benchmark_data

logger = logging.getLogger(__name__)


def _has_test_type(tests, test_type):
    """Check whether any test in the test list has 'testtype' set to the given test_type.

    Args:
        tests: Can be a JSON string, list, or dict
        test_type: The test type to check for (e.g., 'stdin', 'functional')
    """
    # Handle different input types
    if isinstance(tests, str):
        try:
            test_list = json.loads(tests)
        except json.JSONDecodeError:
            return False
    elif isinstance(tests, list):
        test_list = tests
    elif isinstance(tests, dict):
        # If it's a dict, check if it has a list value
        test_list = tests.get("tests", tests.get("public", tests.get("private", [])))
        if not isinstance(test_list, list):
            return False
    else:
        return False

    for test in test_list:
        if isinstance(test, dict) and test.get("testtype") == test_type:
            return True
    return False


def _translate_private_test_cases(encoded_data):
    """Decode and decompress private test cases."""
    decoded_data = base64.b64decode(encoded_data)
    decompressed_data = zlib.decompress(decoded_data)
    original_data = pickle.loads(decompressed_data)
    return json.loads(original_data)


def _update_dataset_in_place(dataset):
    """Helper function to translate the test cases."""
    for i, entry in enumerate(dataset):
        tests = entry.get("tests")
        if tests is None:
            continue

        # Check if already decoded (is a list/dict)
        if isinstance(tests, (list, dict)):
            # Already decoded, no need to decode
            continue

        # Try to decode if it's a string (might be encoded)
        if isinstance(tests, str):
            try:
                # First try to parse as JSON (might already be a JSON string)
                try:
                    decoded = json.loads(tests)
                    entry["tests"] = decoded
                    continue
                except json.JSONDecodeError:
                    pass

                # If not JSON, try to decode as base64/zlib encoded
                decoded_tests = _translate_private_test_cases(tests)
                entry["tests"] = decoded_tests
            except Exception as e:
                logger.warning(f"Failed to decode test cases for entry {i}: {e}")
                # Keep original if decoding fails


def _map_to_dspy_example(row):
    """Map a dataset row to a dspy example format.

    Returns only:
    - prompt: from problem/question_content
    - tests: from tests
    """
    return {
        "prompt": row["question_content"],
        "tests": row["tests"],
    }


class LCBExample(BaseModel):
    """Pydantic model representing a DSPy example with UUID."""

    uuid: UUID
    prompt: str
    tests: Union[List[dict], dict, Any]

    class Config:
        arbitrary_types_allowed = True


class DSPyDataset:
    """
    A dataset that loads benchmark data (e.g., Live Code Bench) and converts it to DSPy examples.
    """

    def __init__(
        self,
        benchmark_name: str,
        max_num_examples: int = None,
    ):
        """
        Initialize the DSPyDataset.

        Args:
            benchmark_name: Name of the benchmark to load via get_benchmark_data.
            max_num_examples: Maximum number of examples to return. If None, returns all examples.
        """
        self.benchmark_name = benchmark_name
        self.train_set, self.test_set = get_benchmark_data(benchmark_name)

        self.examples = self.train_set
        if max_num_examples:
            self.examples = self.train_set[:max_num_examples]

        logger.info(f"DSPyDataset initialized with {len(self.examples)} examples")

    def _load_dataset(self) -> List[dspy.Example]:
        """Load dataset from a JSON file.

        Legacy path: expects self.data_file and self.max_num_examples to be set,
        which the current benchmark-based __init__ does not do.
        """
        if not os.path.exists(self.data_file):
            logger.warning(f"JSON file does not exist: {self.data_file}")
            return []

        if not self.data_file.endswith(".json"):
            logger.warning(f"File is not a JSON file: {self.data_file}")
            return []

        logger.info(f"Loading dataset from JSON file: {self.data_file}")
        examples = self._load_json_file(self.data_file, "train")

        # Apply limit if specified
        if self.max_num_examples is not None and len(examples) > self.max_num_examples:
            examples = examples[:self.max_num_examples]
            logger.info(f"Limited dataset to {self.max_num_examples} examples")

        return examples

    def _load_json_file(self, json_file_path: str, split_name: str = None) -> List[dspy.Example]:
        """Load and process a JSON file into DSPy examples.

        Args:
            json_file_path: Path to the JSON file
            split_name: Name of the split (e.g., "train", "test") for logging purposes
        """
        try:
            with open(json_file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON format in {json_file_path}: {e}")
            return []
        except Exception as e:
            logger.error(f"Error loading JSON file {json_file_path}: {e}")
            return []

        # Convert to list if it's a single dict
        if isinstance(data, dict):
            data = [data]
        elif not isinstance(data, list):
            logger.error(f"JSON file {json_file_path} does not contain a list or dict")
            return []

        # Process entries to match the expected format
        processed_entries = []
        for idx, entry in enumerate(data):
            processed_entry = self._process_json_entry(entry, idx)
            if processed_entry is not None:
                processed_entries.append(processed_entry)

        # Decode test cases if they are encoded
        _update_dataset_in_place(processed_entries)

        # Create dspy.Example objects with UUIDs
        examples = []
        for row in processed_entries:
            example = dspy.Example(**_map_to_dspy_example(row)).with_inputs("prompt", "tests")
            # Add UUID to the example
            example.uuid = uuid4()
            # Add is_stdin to the example
            example.is_stdin = _has_test_type(example.tests, "stdin")

            examples.append(example)
        return examples

    def _process_json_entry(self, entry: dict, index: int = None) -> dict:
        """Process a single JSON entry to match the expected format.

        Based on investigation, the JSON files have:
        - problem: str (the question) -> question_content
        - tests: list of dicts with 'input', 'output', 'testtype' keys -> tests

        Returns dict with only:
        - question_content: for prompt
        - tests: for tests
        """
        # Check if entry already has the expected format
        if "question_content" in entry and "tests" in entry:
            # Already in the expected format
            return entry

        processed = {}

        # Map question/problem -> question_content
        if "problem" in entry:
            processed["question_content"] = entry["problem"]
        elif "question_content" in entry:
            processed["question_content"] = entry["question_content"]
        else:
            logger.warning(f"Entry missing 'problem' or 'question_content': {entry.keys()}")
            return None

        # Map tests -> tests
        # The tests field is a list of dicts with 'input', 'output', 'testtype'
        if "tests" in entry:
            tests = entry["tests"]
            # If it's a string, try to parse as JSON
            if isinstance(tests, str):
                try:
                    tests = json.loads(tests)
                except json.JSONDecodeError:
                    # Not valid JSON; might be encoded and handled later by _update_dataset_in_place
                    processed["tests"] = tests
                    return processed

            # Store tests (already a list, no encoding needed)
            processed["tests"] = tests
        else:
            logger.warning(f"Entry missing 'tests': {entry.keys()}")
            return None

        return processed

    def __getitem__(self, index: int) -> dict:
        """Get an item by index.

        Returns a dict with keys 'prompt' (a dspy.Example), 'env_class',
        'env_extras', and 'uid'.
        """
        if index >= len(self.examples):
            raise IndexError(f"Index {index} out of range for dataset of size {len(self.examples)}")
        return {
            "prompt": self.examples[index],
            "env_class": None,
            "env_extras": None,
            "uid": str(index),
        }

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.examples)

    def __iter__(self):
        """Iterate over all examples."""
        for index, example in enumerate(self.examples):
            yield {
                "prompt": example,
                "env_class": None,
                "env_extras": None,
                "uid": str(index),
            }

    def collate_fn(self, item_list):
        # Identity collate: return the list of item dicts unchanged.
        return item_list
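A sketch of how this dataset might be fed to a PyTorch DataLoader via the identity collate_fn (hypothetical; the benchmark name "lcb" is an assumption about what get_benchmark_data accepts):

# Hypothetical usage: batch the dataset with the identity collate_fn.
from torch.utils.data import DataLoader

dataset = DSPyDataset(benchmark_name="lcb", max_num_examples=8)  # "lcb" is assumed
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

for batch in loader:
    for item in batch:
        # Each item: prompt (a dspy.Example), env_class, env_extras, uid
        print(item["uid"], type(item["prompt"]))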