-
-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #450 from lion-agi/dev_new_direct
added new select
- Loading branch information
Showing
12 changed files
with
705 additions
and
290 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from .action_model import ActionModel | ||
from .brainstorm_model import BrainstormModel | ||
from .plan_model import PlanModel | ||
from .reason_model import ReasonModel | ||
from .step_model import StepModel | ||
|
||
__all__ = [ | ||
"ReasonModel", | ||
"StepModel", | ||
"BrainstormModel", | ||
"ActionModel", | ||
"PlanModel", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import Any | ||
|
||
from lionfuncs import to_dict, validate_str | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
|
||
class ActionModel(BaseModel): | ||
|
||
title: str = Field( | ||
..., | ||
title="Title", | ||
description="Provide a concise title summarizing the action.", | ||
) | ||
content: str = Field( | ||
..., | ||
title="Content", | ||
description="Provide a brief description of the action to be performed.", | ||
) | ||
function: str = Field( | ||
..., | ||
title="Function", | ||
description=( | ||
"Specify the name of the function to execute. **Choose from the provided " | ||
"`tool_schema`; do not invent function names.**" | ||
), | ||
examples=["print", "add", "len"], | ||
) | ||
arguments: dict[str, Any] = Field( | ||
{}, | ||
title="Arguments", | ||
description=( | ||
"Provide the arguments to pass to the function as a dictionary. **Use " | ||
"argument names and types as specified in the `tool_schema`; do not " | ||
"invent argument names.**" | ||
), | ||
examples=[{"num1": 1, "num2": 2}, {"x": "hello", "y": "world"}], | ||
) | ||
|
||
@field_validator("title", mode="before") | ||
def validate_title(cls, value: Any) -> str: | ||
return validate_str(value, "title") | ||
|
||
@field_validator("content", mode="before") | ||
def validate_content(cls, value: Any) -> str: | ||
return validate_str(value, "content") | ||
|
||
@field_validator("function", mode="before") | ||
def validate_function(cls, value: Any) -> str: | ||
return validate_str(value, "function") | ||
|
||
@field_validator("arguments", mode="before") | ||
def validate_arguments(cls, value: Any) -> dict[str, Any]: | ||
return to_dict( | ||
value, | ||
fuzzy_parse=True, | ||
suppress=True, | ||
recursive=True, | ||
) | ||
|
||
|
||
__all__ = ["ActionModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Any | ||
|
||
from lionfuncs import validate_str | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
from .reason_model import ReasonModel | ||
from .step_model import StepModel | ||
|
||
|
||
class BrainstormModel(BaseModel): | ||
|
||
title: str = Field( | ||
..., | ||
title="Title", | ||
description="Provide a concise title summarizing the brainstorming session.", | ||
) | ||
content: str = Field( | ||
..., | ||
title="Content", | ||
description="Describe the context or focus of the brainstorming session.", | ||
) | ||
ideas: list[StepModel] = Field( | ||
..., | ||
title="Ideas", | ||
description="A list of ideas for the next step, generated during brainstorming.", | ||
) | ||
reason: ReasonModel = Field( | ||
..., | ||
title="Reason", | ||
description="Provide the high level reasoning behind the brainstorming session.", | ||
) | ||
|
||
@field_validator("title", mode="before") | ||
def validate_title(cls, value: Any) -> str: | ||
return validate_str(value, "title") | ||
|
||
@field_validator("content", mode="before") | ||
def validate_content(cls, value: Any) -> str: | ||
return validate_str(value, "content") | ||
|
||
|
||
__all__ = ["BrainstormModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Any, List | ||
|
||
from lionfuncs import validate_str | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
from .reason_model import ReasonModel | ||
from .step_model import StepModel | ||
|
||
|
||
class PlanModel(BaseModel): | ||
""" | ||
Represents a plan consisting of multiple steps, with an overall reason. | ||
Attributes: | ||
title (str): A concise title summarizing the plan. | ||
content (str): A detailed description of the plan. | ||
reason (ReasonModel): The overall reasoning behind the plan. | ||
steps (List[StepModel]): A list of steps to execute the plan. | ||
""" | ||
|
||
title: str = Field( | ||
..., | ||
title="Title", | ||
description="Provide a concise title summarizing the plan.", | ||
) | ||
content: str = Field( | ||
..., | ||
title="Content", | ||
description="Provide a detailed description of the plan.", | ||
) | ||
reason: ReasonModel = Field( | ||
..., | ||
title="Reason", | ||
description="Provide the reasoning behind the entire plan.", | ||
) | ||
steps: list[StepModel] = Field( | ||
..., | ||
title="Steps", | ||
description="A list of steps to execute the plan.", | ||
) | ||
|
||
@field_validator("title", mode="before") | ||
def validate_title(cls, value: Any) -> str: | ||
return validate_str(value, "title") | ||
|
||
@field_validator("content", mode="before") | ||
def validate_content(cls, value: Any) -> str: | ||
return validate_str(value, "content") | ||
|
||
|
||
__all__ = ["PlanModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
from typing import Any | ||
|
||
from lionfuncs import to_num, validate_str | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
|
||
class ReasonModel(BaseModel): | ||
title: str = Field( | ||
..., | ||
title="Title", | ||
description="Provide a concise title summarizing the reason.", | ||
) | ||
content: str = Field( | ||
..., | ||
title="Content", | ||
description=( | ||
"Provide a detailed explanation supporting the reason, including relevant " | ||
"information or context." | ||
), | ||
) | ||
confidence_score: float | None = Field( | ||
None, | ||
description=( | ||
"Provide an objective numeric confidence score between 0 and 1 (with 3 " | ||
"decimal places) indicating how likely you successfully achieved the task " | ||
"according to user expectation. Interpret the score as:\n" | ||
"- **1**: Very confident in a good job.\n" | ||
"- **0**: Not confident at all.\n" | ||
"- **[0.8, 1]**: You can continue the path of reasoning if needed.\n" | ||
"- **[0.5, 0.8)**: Recheck your reasoning and consider reverting to a " | ||
"previous, more confident reasoning path.\n" | ||
"- **[0, 0.5)**: Stop because the reasoning is starting to be off track." | ||
), | ||
examples=[0.821, 0.257, 0.923, 0.439], | ||
ge=0, | ||
le=1, | ||
) | ||
|
||
@field_validator("title", mode="before") | ||
def validate_title(cls, value: Any) -> str: | ||
return validate_str(value, "title") | ||
|
||
@field_validator("content", mode="before") | ||
def validate_content(cls, value: Any) -> str: | ||
return validate_str(value, "content") | ||
|
||
@field_validator("confidence_score", mode="before") | ||
def validate_confidence_score(cls, value: Any) -> float: | ||
try: | ||
return to_num( | ||
value, | ||
upper_bound=1, | ||
lower_bound=0, | ||
num_type=float, | ||
precision=3, | ||
) | ||
except Exception as e: | ||
logging.error(f"Failed to convert {value} to a number. Error: {e}") | ||
return 0.0 | ||
|
||
|
||
__all__ = ["ReasonModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import logging | ||
from typing import Any | ||
|
||
from lionfuncs import validate_boolean, validate_str | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
from .action_model import ActionModel | ||
from .reason_model import ReasonModel | ||
|
||
|
||
class StepModel(BaseModel): | ||
title: str = Field( | ||
..., | ||
title="Title", | ||
description="Provide a concise title summarizing the step.", | ||
) | ||
content: str = Field( | ||
..., | ||
title="Content", | ||
description="Describe the content of the step in detail.", | ||
) | ||
reason: ReasonModel = Field( | ||
..., | ||
title="Reason", | ||
description="Provide the reasoning behind this step, including supporting details.", | ||
) | ||
action_required: bool = Field( | ||
False, | ||
title="Action Required", | ||
description=( | ||
"Indicate whether this step requires an action. Set to **True** if an " | ||
"action is required; otherwise, set to **False**." | ||
), | ||
) | ||
actions: list[ActionModel] = Field( | ||
[], | ||
title="Actions", | ||
description=( | ||
"List of actions to be performed if `action_required` is **True**. Leave " | ||
"empty if no action is required. **When providing actions, you must " | ||
"choose from the provided `tool_schema`. Do not invent function or " | ||
"argument names.**" | ||
), | ||
) | ||
|
||
@field_validator("title", mode="before") | ||
def validate_title(cls, value: Any) -> str: | ||
return validate_str(value, "title") | ||
|
||
@field_validator("content", mode="before") | ||
def validate_content(cls, value: Any) -> str: | ||
return validate_str(value, "content") | ||
|
||
@field_validator("action_required", mode="before") | ||
def validate_action_required(cls, value: Any) -> bool: | ||
try: | ||
return validate_boolean(value) | ||
except Exception as e: | ||
logging.error( | ||
f"Failed to convert {value} to a boolean. Error: {e}" | ||
) | ||
return False | ||
|
||
|
||
__all__ = ["StepModel"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable | ||
from enum import Enum | ||
|
||
from lionfuncs import choose_most_similar | ||
from pydantic import BaseModel | ||
|
||
from lionagi.core.director.models import ReasonModel | ||
from lionagi.core.session.branch import Branch | ||
|
||
from .utils import is_enum | ||
|
||
PROMPT = "Please select up to {max_num_selections} items from the following list {choices}. Provide the selection(s), and no comments from you" | ||
|
||
|
||
class SelectionModel(BaseModel): | ||
selected: list[str | Enum] | ||
|
||
|
||
class ReasonSelectionModel(BaseModel): | ||
selected: list[str | Enum] | ||
reason: ReasonModel | ||
|
||
|
||
async def select( | ||
choices: list[str] | type[Enum], | ||
max_num_selections: int = 1, | ||
instruction=None, | ||
context=None, | ||
system=None, | ||
sender=None, | ||
recipient=None, | ||
reason: bool = False, | ||
return_enum: bool = False, | ||
enum_parser: Callable = None, # parse the model string response to appropriate type | ||
branch: Branch = None, | ||
return_pydantic_model=False, | ||
**kwargs, # additional chat arguments | ||
): | ||
selections = [] | ||
if return_enum and not is_enum(choices): | ||
raise ValueError("return_enum can only be True if choices is an Enum") | ||
|
||
if is_enum(choices): | ||
selections = [selection.value for selection in choices] | ||
else: | ||
selections = choices | ||
|
||
prompt = PROMPT.format( | ||
max_num_selections=max_num_selections, choices=selections | ||
) | ||
|
||
if instruction: | ||
prompt = f"{instruction}\n\n{prompt} \n\n " | ||
|
||
branch = branch or Branch() | ||
response: SelectionModel | ReasonSelectionModel | str = await branch.chat( | ||
instruction=prompt, | ||
context=context, | ||
system=system, | ||
sender=sender, | ||
recipient=recipient, | ||
pydantic_model=SelectionModel if not reason else ReasonSelectionModel, | ||
return_pydantic_model=True, | ||
**kwargs, | ||
) | ||
|
||
selected = response | ||
if isinstance(response, SelectionModel | ReasonSelectionModel): | ||
selected = response.selected | ||
selected = [selected] if not isinstance(selected, list) else selected | ||
corrected_selections = [ | ||
choose_most_similar(selection, selections) for selection in selected | ||
] | ||
|
||
if return_enum: | ||
out = [] | ||
if not enum_parser: | ||
enum_parser = lambda x: x | ||
for selection in corrected_selections: | ||
selection = enum_parser(selection) | ||
for member in choices.__members__.values(): | ||
if member.value == selection: | ||
out.append(member) | ||
corrected_selections = out | ||
|
||
if return_pydantic_model: | ||
if not isinstance(response, SelectionModel | ReasonSelectionModel): | ||
return SelectionModel(selected=corrected_selections) | ||
response.selected = corrected_selections | ||
return response | ||
return corrected_selections |
Oops, something went wrong.