Skip to content

Commit

Permalink
Merge pull request #316 from ohdearquant/merge_direct
Browse files Browse the repository at this point in the history
chain of thoughts / react
  • Loading branch information
ohdearquant authored Mar 30, 2024
2 parents 38c7073 + 4921702 commit 0485fdc
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 11 deletions.
5 changes: 4 additions & 1 deletion lionagi/core/direct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@
from .score import score
from .react import react
from .vote import vote
from .plan import plan
from .cot import chain_of_thoughts, chain_of_react

__all__ = ["predict", "select", "score", "vote", "react"]

__all__ = ["predict", "select", "score", "vote", "react", "plan", "chain_of_thoughts", "chain_of_react"]
89 changes: 88 additions & 1 deletion lionagi/core/direct/cot.py
Original file line number Diff line number Diff line change
@@ -1 +1,88 @@
# TODO: chain of thoughts
from typing import Callable
from lionagi.libs import convert
from ..tool import func_to_tool
from ..schema import Tool
from .predict import predict
from .plan import plan
from .react import react

from .utils import _process_tools


async def chain_of_thoughts(
    sentence=None,
    branch=None,
    instruction=None,
    reason=False,
    confidence_score=False,
    num_steps=3,
    directive_kwargs=None,
    return_branch=False,
    **kwargs,
):
    """Run a chain-of-thought directive: draft a step plan, then predict each step.

    Args:
        sentence: Source sentence(s)/context handed to ``plan``.
        branch: Optional existing Branch to reuse; when omitted, ``plan``
            creates one and it is carried through the predict calls.
        instruction: Extra instruction forwarded to the planner.
        reason: When True, collect each step's reason onto the result.
        confidence_score: When True, average the per-step confidence scores.
        num_steps: Number of plan steps to generate and execute.
        directive_kwargs: Extra keyword args forwarded to each ``predict`` call.
        return_branch: When True, return ``(result, branch)``.
        **kwargs: Forwarded to ``plan``.

    Returns:
        The plan template object augmented with ``chain_output`` and
        ``chain_answer`` (plus ``chain_reasons`` / ``chain_confidence_score``
        when requested); a ``(template, branch)`` tuple if ``return_branch``.
    """
    # None sentinel instead of a mutable {} default (shared-default pitfall).
    directive_kwargs = directive_kwargs if directive_kwargs is not None else {}

    if branch is not None:
        out_ = await plan(
            sentence,
            branch=branch,
            instruction=instruction,
            num_steps=num_steps,
            **kwargs,
        )
    else:
        out_, branch = await plan(
            sentence,
            instruction=instruction,
            branch=branch,
            num_steps=num_steps,
            return_branch=True,
            **kwargs,
        )

    outs, answer, reasons = [], "", []
    # Separate accumulator: the original unpacked 0 into `confidence_score`,
    # clobbering the boolean flag so the branches below could never run.
    conf_total = 0

    for i in range(len(out_.plan)):
        # out_.plan is keyed "step_1" .. "step_n" after planning.
        _out = await predict(
            branch=branch,
            instruction=out_.plan[f"step_{i + 1}"],
            reason=reason,
            confidence_score=confidence_score,
            **directive_kwargs,
        )
        answer += _out.answer
        if reason:
            reasons.append(_out.reason)
        if confidence_score:
            conf_total += _out.confidence_score
        outs.append(_out)

    out_.chain_output = outs
    out_.chain_answer = answer

    if reason:
        out_.chain_reasons = reasons
    if confidence_score and outs:
        # Guard against an empty plan to avoid ZeroDivisionError.
        out_.chain_confidence_score = conf_total / len(outs)

    if return_branch:
        return out_, branch

    return out_


async def chain_of_react(
    sentence=None,
    branch=None,
    instruction=None,
    num_steps=3,
    tools=None,
    directive_system=None,
    directive_kwargs=None,
    return_branch=False,
    **kwargs,
):
    """Run a chain-of-react directive: draft a step plan, then react per step.

    Args:
        sentence: Source sentence(s)/context handed to ``plan``.
        branch: Optional existing Branch to reuse; created by ``plan`` when
            omitted.
        instruction: Extra instruction forwarded to the planner.
        num_steps: Number of plan steps to generate and execute.
        tools: Optional tool(s) (callables or Tool objects) to register on
            the branch before reacting.
        directive_system: ``system`` argument forwarded to each ``react`` call.
        directive_kwargs: Extra keyword args forwarded to each ``react`` call.
        return_branch: When True, return ``(result, branch)``.
        **kwargs: Forwarded to ``plan``.

    Returns:
        The plan template object augmented with ``chain_output``,
        ``chain_reason``, ``chain_actions`` and ``chain_action_response``;
        a ``(template, branch)`` tuple if ``return_branch``.
    """
    # None sentinel instead of a mutable {} default (shared-default pitfall).
    directive_kwargs = directive_kwargs if directive_kwargs is not None else {}

    if branch is not None:
        out_ = await plan(
            sentence,
            branch=branch,
            instruction=instruction,
            num_steps=num_steps,
            **kwargs,
        )
    else:
        out_, branch = await plan(
            sentence,
            instruction=instruction,
            branch=branch,
            num_steps=num_steps,
            return_branch=True,
            **kwargs,
        )

    # Only register when tools were actually supplied; _process_tools cannot
    # handle None (the original called it unconditionally).
    if tools:
        _process_tools(tools, branch)

    outs, reasons, actions, action_responses = [], [], [], []

    for i in range(len(out_.plan)):
        # out_.plan is keyed "step_1" .. "step_n" after planning.
        _out = await react(
            branch=branch,
            system=directive_system,
            instruction=out_.plan[f"step_{i + 1}"],
            **directive_kwargs,
        )
        outs.append(_out)
        reasons.append(_out.reason)
        actions.append(_out.actions)
        if _out.action_needed:
            action_responses.append(_out.action_response)

    out_.chain_output = convert.to_list(outs)
    out_.chain_reason = convert.to_list(reasons)
    out_.chain_actions = convert.to_list(actions)
    out_.chain_action_response = convert.to_list(action_responses)

    if return_branch:
        return out_, branch

    return out_
162 changes: 162 additions & 0 deletions lionagi/core/direct/plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# plan.py

from lionagi.libs import func_call, ParseUtil
from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field
from ..prompt.scored_template import ScoredTemplate
from ..branch import Branch


class PlanTemplate(ScoredTemplate):
    """Prompt template asking the model for a step-by-step plan.

    The completed ``plan`` field is expected to follow the
    ``{step_n: {plan: ..., reason: ...}}`` shape; ``_plan`` fuzzy-parses the
    model's JSON text into a dict after the chat call.
    """

    # Identifies this template variant in logs/registries.
    template_name: str = "default_plan"
    sentence: str | list | dict = Field(
        default_factory=str,
        description="the given sentence(s) or context to generate a plan for",
    )
    plan: dict | str = Field(
        default_factory=dict,
        description="the generated step by step plan, return as a dictionary following {step_n: {plan: ..., reason: ...}} format",
    )
    # Directive signature: maps the input field to the expected output field.
    signature: str = "sentence -> plan"

    def __init__(
        self,
        sentence=None,
        instruction=None,
        confidence_score=False,
        reason=False,
        num_step=3,  # NOTE(review): callers of plan() pass `num_steps` — confirm the naming mismatch is intentional
        **kwargs,
    ):
        """Build the task prompt and select the optional output fields.

        Args:
            sentence: Context the plan should be generated for.
            instruction: Extra instruction text embedded into the task prompt.
            confidence_score: When True, ask the model for a confidence score.
            reason: When True, ask the model for its reasoning.
            num_step: Number of steps the generated plan should contain.
            **kwargs: Forwarded to the ScoredTemplate base initializer.
        """
        super().__init__(**kwargs)

        self.sentence = sentence
        self.task = f"Generate a {num_step}_step plan based on the given context. Instruction: {instruction}."

        # Optional output columns requested from the model.
        if reason:
            self.output_fields.append("reason")

        if confidence_score:
            self.output_fields.append("confidence_score")


async def _plan(
    sentence,
    *,
    instruction=None,
    branch=None,
    confidence_score=False,
    reason=False,
    retries=2,
    delay=0.5,
    backoff_factor=2,
    default_value=None,
    timeout=None,
    branch_name=None,
    system=None,
    messages=None,
    service=None,
    sender=None,
    llmconfig=None,
    tools=None,
    datalogger=None,
    persist_path=None,
    tool_manager=None,
    return_branch=False,
    **kwargs,
):
    """Generate a single step-by-step plan for *sentence* via one chat call.

    Builds (or reuses) a Branch, runs the PlanTemplate through
    ``branch.chat`` with retry semantics, and fuzzy-parses the returned
    plan text into a dict.

    Returns:
        The filled PlanTemplate, or ``(template, branch)`` when
        ``return_branch`` is True.
    """
    # Bias toward deterministic planning unless the caller overrides it.
    kwargs.setdefault("temperature", 0.1)

    if not instruction:
        instruction = ""

    if not branch:
        branch = Branch(
            name=branch_name,
            system=system,
            messages=messages,
            service=service,
            sender=sender,
            llmconfig=llmconfig,
            tools=tools,
            datalogger=datalogger,
            persist_path=persist_path,
            tool_manager=tool_manager,
        )

    template = PlanTemplate(
        sentence=sentence,
        instruction=instruction,
        confidence_score=confidence_score,
        reason=reason,
    )

    # rcall wraps branch.chat with retry/backoff/timeout handling; the
    # template is filled in place by the chat call.
    await func_call.rcall(
        branch.chat,
        prompt_template=template,
        retries=retries,
        delay=delay,
        backoff_factor=backoff_factor,
        default=default_value,
        timeout=timeout,
        **kwargs,
    )

    # The model returns the plan as (possibly malformed) JSON text;
    # coerce it into a dict.
    template.plan = ParseUtil.fuzzy_parse_json(template.plan)

    if return_branch:
        return template, branch
    return template


async def plan(
    sentence,
    *,
    instruction=None,
    num_instances=1,
    branch=None,
    confidence_score=False,
    reason=False,
    retries=2,
    delay=0.5,
    backoff_factor=2,
    default_value=None,
    timeout=None,
    branch_name=None,
    system=None,
    messages=None,
    service=None,
    sender=None,
    llmconfig=None,
    tools=None,
    datalogger=None,
    persist_path=None,
    tool_manager=None,
    return_branch=False,
    **kwargs,
):
    """Generate one or more step-by-step plans for *sentence*.

    Args:
        sentence: The context to plan for.
        instruction: Extra instruction forwarded to the planner.
        num_instances: How many independent plans to generate. 1 returns a
            single result; >1 returns a list of results.
        branch: Optional existing Branch; a new one is built per call when
            omitted.
        return_branch: When True, each result is a ``(template, branch)``
            tuple.
        **kwargs: Forwarded to ``branch.chat`` (e.g. ``temperature``).
        (Remaining parameters configure retries and Branch construction and
        are forwarded to ``_plan`` unchanged.)

    Returns:
        A single plan result, or a list of results when ``num_instances > 1``.

    Raises:
        ValueError: If ``num_instances`` is less than 1 (the original code
            silently fell through and returned None).
    """
    if num_instances < 1:
        raise ValueError("num_instances must be a positive integer")

    async def _inner(i=0):
        # `i` is unused but required: func_call.alcall invokes _inner with
        # each element of range(num_instances) positionally.
        return await _plan(
            sentence=sentence,
            instruction=instruction,
            branch=branch,
            confidence_score=confidence_score,
            reason=reason,
            retries=retries,
            delay=delay,
            backoff_factor=backoff_factor,
            default_value=default_value,
            timeout=timeout,
            branch_name=branch_name,
            system=system,
            messages=messages,
            service=service,
            sender=sender,
            llmconfig=llmconfig,
            tools=tools,
            datalogger=datalogger,
            persist_path=persist_path,
            tool_manager=tool_manager,
            return_branch=return_branch,
            **kwargs,
        )

    if num_instances == 1:
        return await _inner()

    # Run the remaining instances concurrently.
    return await func_call.alcall(range(num_instances), _inner)
14 changes: 9 additions & 5 deletions lionagi/core/direct/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ class PredictTemplate(ScoredTemplate):
default_factory=int, description="the number of sentences to predict"
)
answer: str | list = Field(
default_factory=str, description="the predicted sentence(s)"
default_factory=str, description="the predicted sentence(s) or desired output"
)
signature: str = "sentence -> answer"

def __init__(
self,
sentence=None,
num_sentences=None,
instruction=None,
num_sentences=1,
confidence_score=False,
reason=False,
**kwargs,
Expand All @@ -67,9 +68,9 @@ def __init__(
"""
super().__init__(**kwargs)

self.sentence = sentence
self.sentence = sentence or ''
self.num_sentences = num_sentences
self.task = f"predict the next {self.num_sentences} sentence(s)"
self.task = f"follow instruction to predict the next {self.num_sentences} sentence(s). Instruction: {instruction}."

if reason:
self.output_fields.append("reason")
Expand All @@ -82,6 +83,8 @@ async def predict(
sentence=None,
num_sentences=1,
confidence_score=False,
instruction=None,
branch=None,
reason=False,
retries=2,
delay=0.5,
Expand Down Expand Up @@ -128,7 +131,7 @@ async def predict(
Returns:
PredictTemplate: The predict template with the predicted sentence(s).
"""
branch = Branch(
branch = branch or Branch(
name=branch_name,
system=system,
messages=messages,
Expand All @@ -142,6 +145,7 @@ async def predict(
)

predict_template = PredictTemplate(
instruction=instruction,
sentence=sentence,
num_sentences=num_sentences,
confidence_score=confidence_score,
Expand Down
10 changes: 7 additions & 3 deletions lionagi/core/direct/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from lionagi.integrations.bridge.pydantic_.pydantic_bridge import Field
from ..prompt.action_template import ActionedTemplate
from ..branch import Branch
from .utils import _process_tools


class ReactTemplate(ActionedTemplate):
template_name: str = "default_react"
sentence: str | list | dict = Field(
sentence: str | list | dict | None= Field(
default_factory=str,
description="the given sentence(s) to reason and take actions on",
)
Expand All @@ -29,7 +30,7 @@ def __init__(


async def _react(
sentence,
sentence=None,
*,
instruction=None,
branch=None,
Expand Down Expand Up @@ -58,6 +59,9 @@ async def _react(

instruction = instruction or ""

if branch and tools:
_process_tools(tools, branch)

branch = branch or Branch(
name=branch_name,
system=system,
Expand Down Expand Up @@ -109,7 +113,7 @@ async def _react(


async def react(
sentence,
sentence=None,
*,
instruction=None,
num_instances=1,
Expand Down
20 changes: 20 additions & 0 deletions lionagi/core/direct/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Callable
from ..tool import func_to_tool
from ..schema import Tool
# import contextlib
# from lionagi.libs import ParseUtil, StringMatch, convert, func_call

Expand Down Expand Up @@ -85,3 +88,20 @@
# return _out

# return out_ if len(out_) > 1 else out_[0]


def _process_tools(tool_obj, branch):
    """Register one tool or a collection of tools on *branch*.

    Accepts a single callable, a single Tool instance, or an iterable of
    either. The original implementation only special-cased callables, so a
    bare Tool instance fell into the iteration branch (pydantic models
    iterate their fields) and was silently never registered.
    """
    if isinstance(tool_obj, (Callable, Tool)):
        _process_tool(tool_obj, branch)
    else:
        for item in tool_obj:
            _process_tool(item, branch)


def _process_tool(tool_obj, branch):
    """Register a single tool on *branch*, skipping already-registered names.

    A Tool instance is registered as-is; a plain callable is first converted
    with ``func_to_tool``. Anything else is ignored.
    """
    if isinstance(tool_obj, Tool):
        name = tool_obj.schema_["function"]["name"]
        if name not in branch.tool_manager.registry:
            branch.register_tools(tool_obj)

    if isinstance(tool_obj, Callable):
        converted = func_to_tool(tool_obj)[0]
        name = converted.schema_["function"]["name"]
        if name not in branch.tool_manager.registry:
            branch.register_tools(converted)
2 changes: 1 addition & 1 deletion lionagi/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.315"
__version__ = "0.0.316"

0 comments on commit 0485fdc

Please sign in to comment.