13 changes: 13 additions & 0 deletions dspy/adapters/utils.py
@@ -1,3 +1,4 @@

Collaborator: nit: blank line

import ast
import enum
import inspect
@@ -13,6 +14,7 @@
from dspy.adapters.types.base_type import Type
from dspy.signatures.utils import get_dspy_field_type

NoneType = type(None)

def serialize_for_json(value: Any) -> Any:
"""
@@ -132,8 +134,19 @@ def find_enum_member(enum, identifier):

raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}")

def _strip_optional(ann):
"""If ann is Union[..., NoneType] return the non‑None part, else ann."""
if get_origin(ann) is Union and NoneType in get_args(ann):
# keep the first non‑None member (there will be only one in Optional[T])
return next(a for a in get_args(ann) if a is not NoneType)
return ann
Copilot AI (Sep 5, 2025), on lines +137 to +142: The docstring uses an en-dash character (‑) instead of a regular hyphen (-) in 'non‑None'. This should be corrected for consistency and readability.


def parse_value(value, annotation):
annotation = _strip_optional(annotation)
Collaborator (author): This fixes the following failure case.

Previously, parsing failed for fields annotated Union[str, None] with values of type str that could also be parsed as ints, e.g. "9812750".

The problem is in this sequence:
value = "9812750" (a string)
annotation = typing.Optional[str]
candidate = json_repair.loads("9812750") → 9812750 (parsed as an integer, not a str)
TypeAdapter(typing.Optional[str]).validate_python(9812750) → fails with pydantic.ValidationError, since the value is neither a str nor None
The exception handler is triggered: except pydantic.ValidationError as e:
We then hit issubclass(annotation, Type), which raises "issubclass() arg 1 must be a class" because typing.Optional[str] is a type annotation (a Union), not a class.

The fix first strips the Optional annotation to get the expected non-None type and parses the value against that. Pydantic then handles the type coercion correctly as str -> str instead of int -> int when the non-None annotation type is str.
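
A minimal reproduction sketch of the failure described in this thread, outside the real parse_value path (the json_repair step is replaced by the already-coerced int for brevity; pydantic v2's TypeAdapter is assumed):

from typing import Optional, Union, get_args, get_origin

import pydantic

NoneType = type(None)

def _strip_optional(ann):
    """If ann is Union[..., NoneType], return the non-None member, else ann."""
    if get_origin(ann) is Union and NoneType in get_args(ann):
        return next(a for a in get_args(ann) if a is not NoneType)
    return ann

value = "9812750"            # a string that json_repair would coerce to an int
annotation = Optional[str]

# Before the fix: validating the int candidate against Optional[str] fails,
# and the fallback issubclass(annotation, Type) check then raises TypeError
# because Optional[str] is not a class.
try:
    pydantic.TypeAdapter(annotation).validate_python(9812750)
except pydantic.ValidationError:
    print("ValidationError: 9812750 is neither a str nor None")

# After the fix: strip Optional first, so parse_value sees the bare str type
# and returns str(value) directly, without any int coercion.
bare = _strip_optional(annotation)   # -> str
if bare is str:
    print(str(value))                # '9812750' stays a string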

Collaborator: I think we just need to change the condition from issubclass(annotation, Type) to inspect.isclass(annotation) and issubclass(annotation, Type).
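
A minimal sketch of the guard this suggests, using a stand-in Type class in place of dspy.adapters.types.base_type.Type:

import inspect
from typing import Optional

class Type:  # stand-in for dspy.adapters.types.base_type.Type
    pass

def is_custom_type(annotation) -> bool:
    # inspect.isclass short-circuits for Union/Optional annotations, which are
    # not classes and would otherwise make issubclass() raise TypeError.
    return inspect.isclass(annotation) and issubclass(annotation, Type)

print(is_custom_type(Type))           # True
print(is_custom_type(Optional[str]))  # False, and no TypeError is raised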

Collaborator: The current approach doesn't handle str | None, IIUC. Let me file a PR to fix this issue so that we can keep this PR focused on SIMBA.
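
A quick check of why str | None slips past the Union-based test (behavior on Python 3.10-3.13; the PEP 604 form has a different origin than typing.Optional):

import types
from typing import Optional, Union, get_origin

print(get_origin(Optional[str]) is Union)          # True  -> handled by _strip_optional
print(get_origin(str | None) is Union)             # False -> falls through unchanged
print(get_origin(str | None) is types.UnionType)   # True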

Collaborator: #8774, which handles the parse issue.


if value is None:
return None

if annotation is str:
return str(value)

7 changes: 6 additions & 1 deletion dspy/teleprompt/simba.py
@@ -31,6 +31,8 @@ def __init__(
num_candidates: int = 6,
max_steps: int = 8,
max_demos: int = 4,
prompt_model: Any | None = None,
Collaborator (suggested change): replace
    prompt_model: Any | None = None,
with
    prompt_model: dspy.LM | None = None,

Collaborator: btw, can we add type annotations for the other arguments too?

teacher_settings: dict | None = None,
demo_input_field_maxlen: int = 100_000,
num_threads: int | None = None,
temperature_for_sampling: float = 0.2,
@@ -62,6 +64,8 @@ def __init__(
self.num_candidates = num_candidates
self.max_steps = max_steps
self.max_demos = max_demos
self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
Collaborator (suggested change): replace
    self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
with
    self.prompt_model = prompt_model or dspy.settings.lm

self.teacher_settings = teacher_settings
self.demo_input_field_maxlen = demo_input_field_maxlen
self.num_threads = num_threads

@@ -175,7 +179,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:

# We'll generate (program, model) pairs for the trajectory sampling.
# Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0].
models = prepare_models_for_resampling(programs[0], self.num_candidates)
models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings)
top_programs = top_k_plus_baseline(self.num_candidates)

exec_pairs = []
@@ -278,6 +282,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
name2predictor=name2predictor,
batch_10p_score=batch_10th_percentile_score,
batch_90p_score=batch_90th_percentile_score,
prompt_model=self.prompt_model,
)
except Exception as e:
logger.error(f"Strategy failed with error: {e}")
74 changes: 55 additions & 19 deletions dspy/teleprompt/simba_utils.py
@@ -11,13 +11,25 @@

logger = logging.getLogger(__name__)


def prepare_models_for_resampling(program: dspy.Module, n: int):
def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None):
lm = program.get_lm() or dspy.settings.lm
start = lm.kwargs.get("rollout_id", 0)
rollout_ids = [start + i for i in range(n)]
return [lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids]

start_rollout_id = lm.kwargs.get("rollout_id", 0)
rollout_ids = [start_rollout_id + i for i in range(n)]


start_rollout_idx, models = 0, []
# If we have a teacher model, use this as the first model
if teacher_settings:
Collaborator (author): This has been updated to add support for a teacher model (used for one of the N trajectories).

teacher_lm = teacher_settings.get("lm") or lm
teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
models.append(teacher_lm)
Copilot AI (Sep 5, 2025), on lines +25 to +26: [nitpick] Direct mutation of teacher_lm.kwargs could cause side effects if the teacher model is reused elsewhere. Consider using teacher_lm.copy() and setting the rollout_id on the copy instead.
Suggested change: replace
    teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
    models.append(teacher_lm)
with
    models.append(teacher_lm.copy(rollout_id=rollout_ids[start_rollout_idx]))

start_rollout_idx += 1

# The rest of the models are just copies of the base model
models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]])

return models
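
A hypothetical usage sketch of the new arguments (model names are placeholders, the metric is a toy example, and the constructor arguments shown are the ones added in this PR):

import dspy

def exact_match(example, prediction):
    # toy metric for illustration only
    return float(example.answer == prediction.answer)

simba = dspy.SIMBA(
    metric=exact_match,
    num_candidates=6,
    max_steps=8,
    teacher_settings={"lm": dspy.LM("openai/gpt-4o")},   # drives one of the N sampled trajectories
    prompt_model=dspy.LM("openai/gpt-4o-mini"),          # used for rule generation (OfferFeedback)
)
# optimized_program = simba.compile(student=my_program, trainset=trainset)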

def wrap_program(program: dspy.Module, metric: Callable):
def wrapped_program(example):
@@ -26,33 +38,51 @@ def wrapped_program(example):
try:
prediction = program(**example.inputs())
except Exception as e:
print(e)
logger.info(e)
Collaborator: should this be logger.warning or logger.error?

trace = dspy.settings.trace.copy()

output = None
score = 0.0
output_metadata = {}

try:
score = metric(example, prediction)
output = metric(example, prediction)
if isinstance(output, (int, float)):
score = output
elif isinstance(output, dspy.Prediction):
if not hasattr(output, "score"):
raise ValueError("dspy.Prediction must contain a 'score' attribute")
score = output.score
Collaborator (author): Updated to handle additional metric metadata in addition to the score. We check whether the output from the metric is a float or int (in which case we use it as the score) or a dspy.Prediction object, which contains a score plus potentially additional metadata.

Collaborator (@TomeHirata, Sep 5, 2025): We can also use float(output), which might be more intuitive?

# Just extract fields from _store, excluding 'score'
output_metadata = {
k: v for k, v in output._store.items() if k != "score"
Collaborator: can't we use output.items()?

}
except Exception as e:
print(e)
logger.info(e)

# Include the `example` in the output for subsequent usage in buckets/strategies.
return {
"prediction": prediction,
"trace": trace,
"score": score,
"example": example
"example": example,
"output_metadata": output_metadata
}

return wrapped_program
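
Two hypothetical metrics illustrating the return shapes the wrapper now accepts: a bare number used directly as the score, or a dspy.Prediction carrying a score plus extra fields that end up in output_metadata:

import dspy

def numeric_metric(example, prediction):
    return 1.0  # plain int/float -> used directly as the score

def rich_metric(example, prediction):
    # a score field is required; every other field (e.g. judge feedback)
    # is collected into output_metadata for later rule generation
    return dspy.Prediction(score=0.8, feedback="correct answer, but verbose")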



def append_a_demo(demo_input_field_maxlen):
def append_a_demo_(bucket, system, **kwargs):
predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
batch_10p_score = kwargs["batch_10p_score"]

trace = bucket[0]["trace"]
good = bucket[0]
trace = good["trace"]
name2demo = {}

if good["score"] <= batch_10p_score:
Collaborator (author): Double-checking that the demo we're appending is not below the 10th percentile of scores.

Collaborator: ditto

logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.")
return False
Copilot AI (Sep 5, 2025), on lines +82 to +84: [nitpick] The condition check and skip logic for demo appending is duplicated between the append_a_demo_ and append_a_rule functions. Consider extracting it into a shared helper function to reduce code duplication.


for step in trace:
predictor, _inputs, _outputs = step

Expand All @@ -63,7 +93,6 @@ def append_a_demo_(bucket, system, **kwargs):
demo = dspy.Example(augmented=True, **_inputs, **_outputs)
name = predictor2name[id(predictor)]
name2demo[name] = demo # keep the last demo for each predictor

for name, demo in name2demo.items():
predictor = name2predictor[name]
predictor.demos.append(demo)
@@ -77,14 +106,15 @@
def append_a_rule(bucket, system, **kwargs):
predictor2name = kwargs["predictor2name"]
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
prompt_model = kwargs["prompt_model"] or dspy.settings.lm
Collaborator: q: is it possible that prompt_model is not passed? Maybe kwargs.get("prompt_model") is safer.


module_names = [name for name, _ in system.named_predictors()]
good, bad = bucket[0], bucket[-1]
example = good["example"]

if good["score"] < batch_10p_score or bad["score"] > batch_90p_score:
logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile "
f"*or* bad score {bad['score']} is above the 90th percentile.")
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
Collaborator (suggested change): replace
    if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
with
    if good <= batch_10p_score or bad >= batch_90p_score:

logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile "
f"*or* bad score {bad['score']} is at or above the 90th percentile.")
return False

if good["score"] <= bad["score"]:
@@ -117,12 +147,17 @@ def append_a_rule(bucket, system, **kwargs):
"worse_program_outputs": dict(bad["prediction"] or {}),
"worse_reward_value": bad["score"],
"better_reward_value": good["score"],
"worse_reward_info": bad["output_metadata"],
"better_reward_info": good["output_metadata"],
"module_names": module_names,
Collaborator (author): Adding the metric metadata (e.g. feedback from a judge) to help come up with a better set of rules.

}

kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
for k, v in kwargs.items()}
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice

with dspy.settings.context(trace=[], lm=prompt_model):
advice_program = dspy.Predict(OfferFeedback)
advice = advice_program(**kwargs).module_advice

for name, predictor in system.named_predictors():
if name in advice:
@@ -156,11 +191,13 @@ class OfferFeedback(dspy.Signature):
)
worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
better_program_trajectory: str = InputField(
desc="The trajectory of the program's execution, showing each module's I/O"
)
better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
module_advice: dict[str, str] = OutputField(
Expand All @@ -170,7 +207,6 @@ class OfferFeedback(dspy.Signature):
"like the successful trajectory rather than the lower-scoring trajectory."
)


def inspect_modules(program):
separator = "-" * 80
output = [separator]