Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: adding task to the queue from text hmi #284

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/rai_hmi/rai_hmi/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def initialize_agent(hmi_node: BaseHMINode, rai_node: RaiBaseNode, memory: Memor

@tool
def get_mission_memory(uid: str) -> List[MissionMessage]:
"""List mission memory. Mission uid is required."""
"""List mission memory. It contains the information about running tasks. Mission uid is required."""
return memory.get_mission_memory(uid)

@tool
def add_task_to_queue(task: TaskInput):
"""Use this tool to add a task to the queue. The task will be handled by the executor part of your system."""
def submit_mission_to_the_robot(task: TaskInput):
"""Use this tool submit the task to the robot. The task will be handled by the executor part of your system."""

uid = uuid.uuid4()
hmi_node.add_task_to_queue(
hmi_node.execute_mission(
Task(
name=task.name,
description=task.description,
Expand All @@ -55,7 +55,7 @@ def add_task_to_queue(task: TaskInput):
Ros2GetRobotInterfaces(node=rai_node),
GetCameraImage(node=rai_node),
]
task_tools = [add_task_to_queue, get_mission_memory]
task_tools = [submit_mission_to_the_robot, get_mission_memory]
tools = hmi_node.tools + node_tools + task_tools

agent = create_conversational_agent(
Expand Down
7 changes: 3 additions & 4 deletions src/rai_hmi/rai_hmi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
MissionFeedbackMessage,
)
from rai_hmi.task import Task
from rai_hmi.tools import QueryDatabaseTool, QueueTaskTool
from rai_hmi.tools import QueryDatabaseTool
from rai_interfaces.action import Task as TaskAction


Expand Down Expand Up @@ -72,7 +72,7 @@ class BaseHMINode(Node):
If you have multiple questions, please ask them one by one allowing user to respond before
moving forward to the next question. Keep the conversation short and to the point.
If you are requested tasks that you are capable of perfoming as a robot, not as a
conversational agent, please use tools to submit them to the task queue - only 1
conversational agent, please use tools to submit them to the robot - only 1
task in parallel is supported. For more complicated tasks, don't split them, just
add as 1 task.
They will be done by another agent resposible for communication with the robotic's
Expand Down Expand Up @@ -138,7 +138,6 @@ def _initialize_available_tools(self):
tools.append(
QueryDatabaseTool(get_response=self.query_faiss_index_with_scores)
)
tools.append(QueueTaskTool(add_task=self.add_task_to_queue))
return tools

def status_callback(self):
Expand Down Expand Up @@ -183,7 +182,7 @@ def initialize_task_action_client_and_server(self):
# self, TaskFeedback, "provide_task_feedback", self.handle_task_feedback
# )

def add_task_to_queue(self, task: Task):
def execute_mission(self, task: Task):
"""Sends a task to the action server to be handled by the rai node."""

if not self.task_action_client.wait_for_server(timeout_sec=10.0):
Expand Down
20 changes: 0 additions & 20 deletions src/rai_hmi/rai_hmi/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from .task import Task


class QueryDatabaseInput(BaseModel):
query: str = Field(
Expand All @@ -42,21 +40,3 @@ class QueryDatabaseTool(BaseTool):
def _run(self, query: str):
retrieval_response = self.get_response(query)
return str(retrieval_response)


class QueueTaskInput(BaseModel):
task: Task = Field(..., description="The task to queue")


class QueueTaskTool(BaseTool):
name: str = "queue_task"
description: str = "Queue a task for the platform"
input_type: Type[QueueTaskInput] = QueueTaskInput

args_schema: Type[QueueTaskInput] = QueueTaskInput

add_task: Any

def _run(self, task: Task):
self.add_task(task)
return f"Task {task} has been queued for the LLM"