-
-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #316 from ohdearquant/merge_direct
chain of thoughts / react
- Loading branch information
Showing
7 changed files
with
291 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,88 @@ | ||
# TODO: chain of thoughts | ||
from typing import Callable | ||
from lionagi.libs import convert | ||
from ..tool import func_to_tool | ||
from ..schema import Tool | ||
from .predict import predict | ||
from .plan import plan | ||
from .react import react | ||
|
||
from .utils import _process_tools | ||
|
||
|
||
async def chain_of_thoughts( | ||
sentence=None, | ||
branch=None, | ||
instruction=None, | ||
reason=False, | ||
confidence_score=False, | ||
num_steps=3, | ||
directive_kwargs={}, | ||
return_branch=False, | ||
**kwargs | ||
): | ||
|
||
out_, outs, answer, reasons, confidence_score = "", [], "", [], 0 | ||
if branch is not None: | ||
out_ = await plan(sentence, branch=branch, instruction=instruction, num_steps=num_steps, **kwargs) | ||
else: | ||
out_, branch = await plan(sentence, instruction=instruction, branch=branch, num_steps=num_steps, return_branch=True, **kwargs) | ||
|
||
for i in range(len(out_.plan)): | ||
_out = await predict(branch=branch, instruction=out_.plan[f"step_{i+1}"], reason=reason, confidence_score=confidence_score, **directive_kwargs) | ||
answer += _out.answer | ||
if reason: | ||
reasons.append(_out.reason) | ||
if confidence_score: | ||
confidence_score += _out.confidence_score | ||
outs.append(_out) | ||
|
||
setattr(out_, "chain_output", outs) | ||
setattr(out_, "chain_answer", answer) | ||
|
||
if reason: | ||
setattr(out_, "chain_reasons", reasons) | ||
if confidence_score: | ||
setattr(out_, "chain_confidence_score", confidence_score/len(outs)) | ||
|
||
if return_branch: | ||
return out_, branch | ||
|
||
return out_ | ||
|
||
|
||
async def chain_of_react( | ||
sentence=None, | ||
branch=None, | ||
instruction=None, | ||
num_steps=3, | ||
tools=None, | ||
directive_system=None, | ||
directive_kwargs={}, | ||
return_branch=False, | ||
**kwargs | ||
): | ||
out_, outs, reasons, actions, action_responses = "", [], [], [], [] | ||
if branch is not None: | ||
out_ = await plan(sentence, branch=branch, instruction=instruction, num_steps=num_steps, **kwargs) | ||
else: | ||
out_, branch = await plan(sentence, instruction=instruction, branch=branch, num_steps=num_steps, return_branch=True, **kwargs) | ||
|
||
_process_tools(tools, branch) | ||
|
||
for i in range(len(out_.plan)): | ||
_out = await react(branch=branch, system=directive_system, instruction=out_.plan[f"step_{i+1}"], **directive_kwargs) | ||
outs.append(_out) | ||
reasons.append(_out.reason) | ||
actions.append(_out.actions) | ||
if _out.action_needed: | ||
action_responses.append(_out.action_response) | ||
|
||
setattr(out_, "chain_output", convert.to_list(outs)) | ||
setattr(out_, "chain_reason", convert.to_list(reasons)) | ||
setattr(out_, "chain_actions", convert.to_list(actions)) | ||
setattr(out_, "chain_action_response", convert.to_list(action_responses)) | ||
|
||
if return_branch: | ||
return out_, branch | ||
|
||
return out_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# plan.py | ||
|
||
from lionagi.libs import func_call, ParseUtil | ||
from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field | ||
from ..prompt.scored_template import ScoredTemplate | ||
from ..branch import Branch | ||
|
||
|
||
class PlanTemplate(ScoredTemplate): | ||
template_name: str = "default_plan" | ||
sentence: str | list | dict = Field( | ||
default_factory=str, | ||
description="the given sentence(s) or context to generate a plan for", | ||
) | ||
plan: dict | str= Field( | ||
default_factory=dict, description="the generated step by step plan, return as a dictionary following {step_n: {plan: ..., reason: ...}} format") | ||
signature: str = "sentence -> plan" | ||
|
||
def __init__( | ||
self, | ||
sentence=None, | ||
instruction=None, | ||
confidence_score=False, | ||
reason=False, | ||
num_step=3, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
|
||
self.sentence = sentence | ||
self.task = f"Generate a {num_step}_step plan based on the given context. Instruction: {instruction}." | ||
|
||
if reason: | ||
self.output_fields.append("reason") | ||
|
||
if confidence_score: | ||
self.output_fields.append("confidence_score") | ||
|
||
|
||
async def _plan( | ||
sentence, | ||
*, | ||
instruction=None, | ||
branch=None, | ||
confidence_score=False, | ||
reason=False, | ||
retries=2, | ||
delay=0.5, | ||
backoff_factor=2, | ||
default_value=None, | ||
timeout=None, | ||
branch_name=None, | ||
system=None, | ||
messages=None, | ||
service=None, | ||
sender=None, | ||
llmconfig=None, | ||
tools=None, | ||
datalogger=None, | ||
persist_path=None, | ||
tool_manager=None, | ||
return_branch=False, | ||
**kwargs, | ||
): | ||
if "temperature" not in kwargs: | ||
kwargs["temperature"] = 0.1 | ||
|
||
instruction = instruction or "" | ||
|
||
branch = branch or Branch( | ||
name=branch_name, | ||
system=system, | ||
messages=messages, | ||
service=service, | ||
sender=sender, | ||
llmconfig=llmconfig, | ||
tools=tools, | ||
datalogger=datalogger, | ||
persist_path=persist_path, | ||
tool_manager=tool_manager, | ||
) | ||
|
||
_template = PlanTemplate( | ||
sentence=sentence, | ||
instruction=instruction, | ||
confidence_score=confidence_score, | ||
reason=reason, | ||
) | ||
|
||
await func_call.rcall( | ||
branch.chat, | ||
prompt_template=_template, | ||
retries=retries, | ||
delay=delay, | ||
backoff_factor=backoff_factor, | ||
default=default_value, | ||
timeout=timeout, | ||
**kwargs, | ||
) | ||
|
||
_template.plan = ParseUtil.fuzzy_parse_json(_template.plan) | ||
|
||
return (_template, branch) if return_branch else _template | ||
|
||
|
||
async def plan( | ||
sentence, | ||
*, | ||
instruction=None, | ||
num_instances=1, | ||
branch=None, | ||
confidence_score=False, | ||
reason=False, | ||
retries=2, | ||
delay=0.5, | ||
backoff_factor=2, | ||
default_value=None, | ||
timeout=None, | ||
branch_name=None, | ||
system=None, | ||
messages=None, | ||
service=None, | ||
sender=None, | ||
llmconfig=None, | ||
tools=None, | ||
datalogger=None, | ||
persist_path=None, | ||
tool_manager=None, | ||
return_branch=False, | ||
**kwargs, | ||
): | ||
async def _inner(i=0): | ||
return await _plan( | ||
sentence=sentence, | ||
instruction=instruction, | ||
branch=branch, | ||
confidence_score=confidence_score, | ||
reason=reason, | ||
retries=retries, | ||
delay=delay, | ||
backoff_factor=backoff_factor, | ||
default_value=default_value, | ||
timeout=timeout, | ||
branch_name=branch_name, | ||
system=system, | ||
messages=messages, | ||
service=service, | ||
sender=sender, | ||
llmconfig=llmconfig, | ||
tools=tools, | ||
datalogger=datalogger, | ||
persist_path=persist_path, | ||
tool_manager=tool_manager, | ||
return_branch=return_branch, | ||
**kwargs, | ||
) | ||
|
||
if num_instances == 1: | ||
return await _inner() | ||
|
||
elif num_instances > 1: | ||
return await func_call.alcall(range(num_instances), _inner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.0.315" | ||
__version__ = "0.0.316" |