diff --git a/agentverse/tasksolving.py b/agentverse/tasksolving.py index 2bd2bd9b7..d9bc9628c 100644 --- a/agentverse/tasksolving.py +++ b/agentverse/tasksolving.py @@ -37,14 +37,14 @@ def from_task(cls, task: str, tasks_dir: str): # Build agents for all pipeline (task) agents = {} for i, agent_config in enumerate(task_config["agents"]): - agent_type = AGENT_TYPES(i) - if i == 2 and agent_config.get("agent_type", "") == "critic": + if agent_config.get("agent_type", "") == "critic": agent = load_agent(agent_config) - agents[agent_type] = [ + agents[AGENT_TYPES.CRITIC] = [ copy.deepcopy(agent) for _ in range(task_config.get("cnt_agents", 1) - 1) ] else: + agent_type = AGENT_TYPES.from_string(agent_config.get("agent_type", "")) agents[agent_type] = load_agent(agent_config) env_config["agents"] = agents diff --git a/agentverse/utils.py b/agentverse/utils.py index 196d0ba2e..1eeb198fb 100644 --- a/agentverse/utils.py +++ b/agentverse/utils.py @@ -35,6 +35,21 @@ class AGENT_TYPES(Enum): EVALUATION = 4 MANAGER = 5 + @staticmethod + def from_string(agent_type: str): + str_to_enum_dict = { + "role_assigner": AGENT_TYPES.ROLE_ASSIGNMENT, + "solver": AGENT_TYPES.SOLVER, + "critic": AGENT_TYPES.CRITIC, + "executor": AGENT_TYPES.EXECUTION, + "evaluator": AGENT_TYPES.EVALUATION, + "manager": AGENT_TYPES.MANAGER, + } + assert ( + agent_type in str_to_enum_dict + ), f"Unknown agent type: {agent_type}. Check your config file." + return str_to_enum_dict.get(agent_type.lower()) + class Singleton(abc.ABCMeta, type): """