
Commit

Merge pull request #314 from ohdearquant/merge_direct
v0.0.314 one step react
ohdearquant authored Mar 29, 2024
2 parents 015a9e1 + 5eb557b commit 5914863
Showing 17 changed files with 445 additions and 57 deletions.
3 changes: 2 additions & 1 deletion lionagi/core/direct/__init__.py
@@ -1,6 +1,7 @@
from .predict import predict
from .select import select
from .score import score
from .react import react
from .vote import vote

__all__ = ["predict", "select", "score", "vote"]
__all__ = ["predict", "select", "score", "vote", "react"]
5 changes: 3 additions & 2 deletions lionagi/core/direct/predict.py
@@ -6,9 +6,10 @@
confidence score, and reason for the prediction.
"""

from pydantic import Field
from lionagi.libs import func_call
from ..prompt.prompt_template import ScoredTemplate
from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field

from ..prompt.scored_template import ScoredTemplate
from ..branch import Branch


167 changes: 167 additions & 0 deletions lionagi/core/direct/react.py
@@ -0,0 +1,167 @@
from lionagi.libs import func_call, convert, AsyncUtil

from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field
from ..prompt.action_template import ActionedTemplate
from ..branch import Branch


class ReactTemplate(ActionedTemplate):
template_name: str = "default_react"
sentence: str | list | dict = Field(
default_factory=str,
description="the given sentence(s) to reason and take actions on",
)

def __init__(
self,
sentence=None,
instruction=None,
confidence_score=False,
**kwargs,
):
super().__init__(**kwargs)

self.sentence = sentence
self.task = f"Think step by step. Perform reasoning and prepare actions with given tools only.Instruction: {instruction}. Absolutely DO NOT MAKE UP FUNCTIONS !!!"

if confidence_score:
self.output_fields.append("confidence_score")


async def _react(
sentence,
*,
instruction=None,
branch=None,
confidence_score=False,
retries=2,
delay=0.5,
backoff_factor=2,
default_value=None,
timeout=None,
branch_name=None,
system=None,
messages=None,
service=None,
sender=None,
llmconfig=None,
tools=None,
datalogger=None,
persist_path=None,
tool_manager=None,
return_branch=False,
**kwargs,
):

if "temperature" not in kwargs:
kwargs["temperature"] = 0.1

instruction = instruction or ""

branch = branch or Branch(
name=branch_name,
system=system,
messages=messages,
service=service,
sender=sender,
llmconfig=llmconfig,
tools=tools,
datalogger=datalogger,
persist_path=persist_path,
tool_manager=tool_manager,
)

_template = ReactTemplate(
sentence=sentence,
instruction=instruction,
confidence_score=confidence_score,
)

await func_call.rcall(
branch.chat,
prompt_template=_template,
retries=retries,
delay=delay,
backoff_factor=backoff_factor,
default=default_value,
timeout=timeout,
**kwargs,
)

if _template.action_needed:
actions = _template.actions
tasks = [branch.tool_manager.invoke(i.values()) for i in actions]
results = await AsyncUtil.execute_tasks(*tasks)

a = []
for idx, item in enumerate(actions):
res = {
"function": item["function"],
"arguments": item["arguments"],
"output": results[idx],
}
branch.add_message(response=res)
a.append(res)

_template.__setattr__("action_response", a)

return (_template, branch) if return_branch else _template


async def react(
sentence,
*,
instruction=None,
num_instances=1,
branch=None,
confidence_score=False,
retries=2,
delay=0.5,
backoff_factor=2,
default_value=None,
timeout=None,
branch_name=None,
system=None,
messages=None,
service=None,
sender=None,
llmconfig=None,
tools=None,
datalogger=None,
persist_path=None,
tool_manager=None,
return_branch=False,
**kwargs,
):

async def _inner(i=0):
return await _react(
sentence=sentence,
instruction=instruction,
num_instances=num_instances,
branch=branch,
confidence_score=confidence_score,
retries=retries,
delay=delay,
backoff_factor=backoff_factor,
default_value=default_value,
timeout=timeout,
branch_name=branch_name,
system=system,
messages=messages,
service=service,
sender=sender,
llmconfig=llmconfig,
tools=tools,
datalogger=datalogger,
persist_path=persist_path,
tool_manager=tool_manager,
return_branch=return_branch,
**kwargs,
)

if num_instances == 1:
return await _inner()

elif num_instances > 1:
return await func_call.alcall(range(num_instances), _inner)
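
In this new module, `react` is the user-facing wrapper and `_react` performs a single reason-and-act pass on a `Branch`. A minimal usage sketch based only on the signatures above — the tool function, the prompt, and the idea that plain callables can be passed via `tools=` are illustrative assumptions, not confirmed by this diff:

```python
# Hypothetical usage sketch; the tool and prompts below are made up for illustration.
import asyncio

from lionagi.core.direct import react


def multiply(number1: float, number2: float) -> float:
    """Multiply two numbers."""  # example tool the model may call
    return number1 * number2


async def main():
    out = await react(
        "What is 7 times 6?",
        instruction="Use the provided tool to compute the product.",
        tools=[multiply],        # forwarded to the Branch and its tool_manager (assumed)
        confidence_score=True,   # appends confidence_score to the template's output_fields
    )
    # out is the filled-in ReactTemplate; action_response is set when actions were invoked.
    print(out.answer)
    print(getattr(out, "action_response", None))


asyncio.run(main())
```

With `num_instances > 1`, the same call is fanned out through `func_call.alcall` and a list of filled templates comes back instead of a single one.
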
3 changes: 2 additions & 1 deletion lionagi/core/direct/score.py
Expand Up @@ -12,7 +12,7 @@
from pydantic import Field
import numpy as np
from lionagi.libs import func_call, convert
from ..prompt.prompt_template import ScoredTemplate
from ..prompt.scored_template import ScoredTemplate
from ..branch import Branch


@@ -183,6 +183,7 @@ async def _score(

async def score(
sentence,
*,
num_instances=1,
instruction=None,
score_range=(1, 10),
5 changes: 4 additions & 1 deletion lionagi/core/direct/select.py
@@ -13,7 +13,7 @@
from pydantic import Field

from lionagi.libs import func_call, StringMatch
from ..prompt.prompt_template import ScoredTemplate
from ..prompt.scored_template import ScoredTemplate
from ..branch import Branch


@@ -39,6 +39,9 @@ class SelectTemplate(ScoredTemplate):
answer: Enum | str = Field(
default_factory=str, description="selection from given choices"
)
choices: list = Field(
default_factory=list, description="the given choices"
)

signature: str = "sentence -> answer"

2 changes: 1 addition & 1 deletion lionagi/core/messages/schema.py
@@ -173,7 +173,7 @@ def __init__(

if output_fields:
format_ = f"""
Follow the following response format.
MUST EXACTLY Follow the following response format. NO ADDITIONAL COMMENTS ALLOWED!
```json
{output_fields}
```
26 changes: 26 additions & 0 deletions lionagi/core/prompt/action_template.py
@@ -0,0 +1,26 @@
from typing import Any
from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field

from .scored_template import ScoredTemplate


class ActionRequest: ...


class ActionedTemplate(ScoredTemplate):

action_needed: bool | None = Field(
False, description="true if actions are needed else false"
)

actions: list[dict | ActionRequest | Any] | None = Field(
default_factory=list,
description="""provide The list of action(s) to take, each action in {"function": function_name, "arguments": {param1:..., param2:..., ...}}. Leave blank if no further actions are needed, you must use provided parameters for each action, DO NOT MAKE UP KWARG NAME!!!""",
)

answer: str | dict | Any | None = Field(
default_factory=str,
description="output answer to the questions asked if further actions are not needed, leave blank if an accurate answer cannot be provided from context during this step",
)

signature: str = "sentence -> reason, action_needed, actions, answer"
41 changes: 40 additions & 1 deletion lionagi/core/prompt/field_validator.py
@@ -6,7 +6,45 @@
maps data types to their corresponding validation functions.
"""

from lionagi.libs import convert, StringMatch
from lionagi.libs import convert, StringMatch, ParseUtil


def _has_action_keys(dict_):
return list(dict_.keys()) >= ["function", "arguments"]


def check_action_field(x, fix_=True, **kwargs):
if (
isinstance(x, list)
and convert.is_same_dtype(x, dict)
and all(_has_action_keys(y) for y in x)
):
return x
try:
x = _fix_action_field(x, fix_)
return x
except Exception as e:
raise ValueError("Invalid action field type.") from e


def _fix_action_field(x, discard_=True):
corrected = []
if isinstance(x, str):
x = ParseUtil.fuzzy_parse_json(x)

try:
x = convert.to_list(x)

for i in x:
i = convert.to_dict(i)
if _has_action_keys(i):
corrected.append(i)
elif not discard_:
raise ValueError(f"Invalid action field: {i}")
except Exception as e:
raise ValueError(f"Invalid action field: {e}") from e

return corrected


def check_number_field(x, fix_=True, **kwargs):
@@ -236,4 +274,5 @@ def _fix_enum_field(x, choices, **kwargs):
"bool": check_bool_field,
"str": check_str_field,
"enum": check_enum_field,
"action": check_action_field,
}
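
The new "action" validator passes a well-formed list of action dicts through unchanged and otherwise tries to repair the value (for example, a JSON string returned by the model) via `_fix_action_field`. A quick sketch of both paths, assuming the `libs` helpers (`fuzzy_parse_json`, `to_list`, `to_dict`) behave as their names suggest:

```python
# Sketch of the two paths in check_action_field; import path taken from the file shown above.
from lionagi.core.prompt.field_validator import check_action_field

# Already valid: a list of dicts keyed by "function" and "arguments" is returned as-is.
good = [{"function": "multiply", "arguments": {"number1": 7, "number2": 6}}]
print(check_action_field(good))

# A JSON string is fuzzy-parsed and coerced back into a list of action dicts.
raw = '[{"function": "multiply", "arguments": {"number1": 7, "number2": 6}}]'
print(check_action_field(raw))
```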