-
Notifications
You must be signed in to change notification settings - Fork 2.2k
SIMBA Improvements #8766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
SIMBA Improvements #8766
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
|
||
import ast | ||
import enum | ||
import inspect | ||
|
@@ -13,6 +14,7 @@ | |
from dspy.adapters.types.base_type import Type | ||
from dspy.signatures.utils import get_dspy_field_type | ||
|
||
NoneType = type(None) | ||
|
||
def serialize_for_json(value: Any) -> Any: | ||
""" | ||
|
@@ -132,8 +134,19 @@ def find_enum_member(enum, identifier): | |
|
||
raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}") | ||
|
||
def _strip_optional(ann): | ||
"""If ann is Union[..., NoneType] return the non‑None part, else ann.""" | ||
if get_origin(ann) is Union and NoneType in get_args(ann): | ||
# keep the first non-None member (there will be only one in Optional[T]) | ||
return next(a for a in get_args(ann) if a is not NoneType) | ||
return ann | ||
Comment on lines
+137
to
+142
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The docstring uses an en-dash character (‑) instead of a regular hyphen (-) in 'non‑None'. This should be corrected for consistency and readability. Copilot uses AI. Check for mistakes. Positive Feedback | Negative Feedback |
||
|
||
def parse_value(value, annotation): | ||
annotation = _strip_optional(annotation) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This fixes the following failure case: Previously, this was failing for fields of Union(str, None) with values of type 'str' that could also be parsed as ints. Ex: "9812750". The problem is in this sequence: This fix involves first parsing the Optional field to get the expected non-Null type, and parse the value according to this. Now pydantic handles the type coercion correctly from str -> str instead of int -> int when the non-null annotation type is 'str' There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we just need to change the condition from There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current approach doesn't handle str | None IIUC, let me file a PR to fix this issue so that we can keep this PR focused on SIMBA. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. #8774, which handles the parse issue |
||
|
||
if value is None: | ||
return None | ||
|
||
if annotation is str: | ||
return str(value) | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -31,6 +31,8 @@ def __init__( | |||||
num_candidates: int = 6, | ||||||
max_steps: int = 8, | ||||||
max_demos: int = 4, | ||||||
prompt_model: Any | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. btw, can we add type annotations for other arguments too? |
||||||
teacher_settings: dict | None = None, | ||||||
demo_input_field_maxlen: int = 100_000, | ||||||
num_threads: int | None = None, | ||||||
temperature_for_sampling: float = 0.2, | ||||||
|
@@ -62,6 +64,8 @@ def __init__( | |||||
self.num_candidates = num_candidates | ||||||
self.max_steps = max_steps | ||||||
self.max_demos = max_demos | ||||||
self.prompt_model = prompt_model if prompt_model else dspy.settings.lm | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
self.teacher_settings = teacher_settings | ||||||
self.demo_input_field_maxlen = demo_input_field_maxlen | ||||||
self.num_threads = num_threads | ||||||
|
||||||
|
@@ -175,7 +179,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None: | |||||
|
||||||
# We'll generate (program, model) pairs for the trajectory sampling. | ||||||
# Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0]. | ||||||
models = prepare_models_for_resampling(programs[0], self.num_candidates) | ||||||
models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings) | ||||||
top_programs = top_k_plus_baseline(self.num_candidates) | ||||||
|
||||||
exec_pairs = [] | ||||||
|
@@ -278,6 +282,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None: | |||||
name2predictor=name2predictor, | ||||||
batch_10p_score=batch_10th_percentile_score, | ||||||
batch_90p_score=batch_90th_percentile_score, | ||||||
prompt_model=self.prompt_model, | ||||||
) | ||||||
except Exception as e: | ||||||
logger.error(f"Strategy failed with error: {e}") | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -11,13 +11,25 @@ | |||||||
|
||||||||
logger = logging.getLogger(__name__) | ||||||||
|
||||||||
|
||||||||
def prepare_models_for_resampling(program: dspy.Module, n: int): | ||||||||
def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None): | ||||||||
lm = program.get_lm() or dspy.settings.lm | ||||||||
start = lm.kwargs.get("rollout_id", 0) | ||||||||
rollout_ids = [start + i for i in range(n)] | ||||||||
return [lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids] | ||||||||
|
||||||||
start_rollout_id = lm.kwargs.get("rollout_id", 0) | ||||||||
rollout_ids = [start_rollout_id + i for i in range(n)] | ||||||||
|
||||||||
|
||||||||
start_rollout_idx, models = 0, [] | ||||||||
# If we have a teacher model, use this as the first model | ||||||||
if teacher_settings: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This has been updated to add support for a teacher model (used for 1 of the N trajectories) |
||||||||
teacher_lm = teacher_settings.get("lm") or lm | ||||||||
teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx] | ||||||||
models.append(teacher_lm) | ||||||||
Comment on lines
+25
to
+26
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [nitpick] Direct mutation of
Suggested change
Copilot uses AI. Check for mistakes. Positive Feedback | Negative Feedback |
||||||||
start_rollout_idx += 1 | ||||||||
|
||||||||
# The rest of the models are just copies of the base model | ||||||||
models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]]) | ||||||||
|
||||||||
return models | ||||||||
|
||||||||
def wrap_program(program: dspy.Module, metric: Callable): | ||||||||
def wrapped_program(example): | ||||||||
|
@@ -26,33 +38,51 @@ def wrapped_program(example): | |||||||
try: | ||||||||
prediction = program(**example.inputs()) | ||||||||
except Exception as e: | ||||||||
print(e) | ||||||||
logger.info(e) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should this be |
||||||||
trace = dspy.settings.trace.copy() | ||||||||
|
||||||||
output = None | ||||||||
score = 0.0 | ||||||||
output_metadata = {} | ||||||||
|
||||||||
try: | ||||||||
score = metric(example, prediction) | ||||||||
output = metric(example, prediction) | ||||||||
if isinstance(output, (int, float)): | ||||||||
score = output | ||||||||
elif isinstance(output, dspy.Prediction): | ||||||||
if not hasattr(output, "score"): | ||||||||
raise ValueError("dspy.Prediction must contain a 'score' attribute") | ||||||||
score = output.score | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated to handle additional metric metadata in addition to the score. To do this, we check if the output from the metric is a float or int (in which case we use it as the score) or a dspy.Prediction object, which contains a score + potentially additional meta-data There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can also use |
||||||||
# Just extract fields from _store, excluding 'score' | ||||||||
output_metadata = { | ||||||||
k: v for k, v in output._store.items() if k != "score" | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can't we use |
||||||||
} | ||||||||
except Exception as e: | ||||||||
print(e) | ||||||||
logger.info(e) | ||||||||
|
||||||||
# Include the `example` in the output for subsequent usage in buckets/strategies. | ||||||||
return { | ||||||||
"prediction": prediction, | ||||||||
"trace": trace, | ||||||||
"score": score, | ||||||||
"example": example | ||||||||
"example": example, | ||||||||
"output_metadata": output_metadata | ||||||||
} | ||||||||
|
||||||||
return wrapped_program | ||||||||
|
||||||||
|
||||||||
|
||||||||
def append_a_demo(demo_input_field_maxlen): | ||||||||
def append_a_demo_(bucket, system, **kwargs): | ||||||||
predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] | ||||||||
batch_10p_score = kwargs["batch_10p_score"] | ||||||||
|
||||||||
trace = bucket[0]["trace"] | ||||||||
good = bucket[0] | ||||||||
trace = good["trace"] | ||||||||
name2demo = {} | ||||||||
|
||||||||
if good["score"] <= batch_10p_score: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Double checking that the demo we're appending is not below the 10th percentile of scores There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto |
||||||||
logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.") | ||||||||
return False | ||||||||
Comment on lines
+82
to
+84
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [nitpick] The condition check and logic for skipping demo appending is duplicated between Copilot uses AI. Check for mistakes. Positive Feedback | Negative Feedback |
||||||||
|
||||||||
for step in trace: | ||||||||
predictor, _inputs, _outputs = step | ||||||||
|
||||||||
|
@@ -63,7 +93,6 @@ def append_a_demo_(bucket, system, **kwargs): | |||||||
demo = dspy.Example(augmented=True, **_inputs, **_outputs) | ||||||||
name = predictor2name[id(predictor)] | ||||||||
name2demo[name] = demo # keep the last demo for each predictor | ||||||||
|
||||||||
for name, demo in name2demo.items(): | ||||||||
predictor = name2predictor[name] | ||||||||
predictor.demos.append(demo) | ||||||||
|
@@ -77,14 +106,15 @@ def append_a_demo_(bucket, system, **kwargs): | |||||||
def append_a_rule(bucket, system, **kwargs): | ||||||||
predictor2name = kwargs["predictor2name"] | ||||||||
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] | ||||||||
prompt_model = kwargs["prompt_model"] or dspy.settings.lm | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. q: is it possible that |
||||||||
|
||||||||
module_names = [name for name, _ in system.named_predictors()] | ||||||||
good, bad = bucket[0], bucket[-1] | ||||||||
example = good["example"] | ||||||||
|
||||||||
if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: | ||||||||
logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " | ||||||||
f"*or* bad score {bad['score']} is above the 90th percentile.") | ||||||||
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile " | ||||||||
f"*or* bad score {bad['score']} is at or above the 90th percentile.") | ||||||||
return False | ||||||||
|
||||||||
if good["score"] <= bad["score"]: | ||||||||
|
@@ -117,12 +147,17 @@ def append_a_rule(bucket, system, **kwargs): | |||||||
"worse_program_outputs": dict(bad["prediction"] or {}), | ||||||||
"worse_reward_value": bad["score"], | ||||||||
"better_reward_value": good["score"], | ||||||||
"worse_reward_info": bad["output_metadata"], | ||||||||
"better_reward_info": good["output_metadata"], | ||||||||
"module_names": module_names, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. adding in the metric meta-data (ex. feedback from a judge) to help come up with a better set of rules |
||||||||
} | ||||||||
|
||||||||
kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode() | ||||||||
for k, v in kwargs.items()} | ||||||||
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice | ||||||||
|
||||||||
with dspy.settings.context(trace=[], lm=prompt_model): | ||||||||
advice_program = dspy.Predict(OfferFeedback) | ||||||||
advice = advice_program(**kwargs).module_advice | ||||||||
|
||||||||
for name, predictor in system.named_predictors(): | ||||||||
if name in advice: | ||||||||
|
@@ -156,11 +191,13 @@ class OfferFeedback(dspy.Signature): | |||||||
) | ||||||||
worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") | ||||||||
worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") | ||||||||
worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") | ||||||||
better_program_trajectory: str = InputField( | ||||||||
desc="The trajectory of the program's execution, showing each module's I/O" | ||||||||
) | ||||||||
better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") | ||||||||
better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") | ||||||||
better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") | ||||||||
module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") | ||||||||
discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") | ||||||||
module_advice: dict[str, str] = OutputField( | ||||||||
|
@@ -170,7 +207,6 @@ class OfferFeedback(dspy.Signature): | |||||||
"like the successful trajectory rather than the lower-scoring trajectory." | ||||||||
) | ||||||||
|
||||||||
|
||||||||
def inspect_modules(program): | ||||||||
separator = "-" * 80 | ||||||||
output = [separator] | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: blank line