diff --git a/mle/agents/__init__.py b/mle/agents/__init__.py
index e49de82..0d8a924 100644
--- a/mle/agents/__init__.py
+++ b/mle/agents/__init__.py
@@ -4,3 +4,4 @@
 from .planner import *
 from .summarizer import *
 from .reporter import *
+from .chat import *
diff --git a/mle/agents/chat.py b/mle/agents/chat.py
new file mode 100644
index 0000000..5d546f7
--- /dev/null
+++ b/mle/agents/chat.py
@@ -0,0 +1,129 @@
+import sys
+import json
+from rich.console import Console
+
+from mle.function import *
+from mle.utils import get_config, print_in_box, WorkflowCache
+
+
+class ChatAgent:
+
+    def __init__(self, model, working_dir='.', console=None):
+        """
+        ChatAgent assists users with planning and debugging ML projects.
+
+        Args:
+            model: The machine learning model used for generating responses.
+        """
+        config_data = get_config()
+
+        self.model = model
+        self.chat_history = []
+        self.working_dir = working_dir
+        self.cache = WorkflowCache(working_dir, 'baseline')
+
+        self.console = console
+        if not self.console:
+            self.console = Console()
+
+        self.sys_prompt = f"""
+        You are a programmer working on a Machine Learning task using Python.
+        You are currently working on: {self.working_dir}.
+
+        You can leverage your capabilities by using the specific functions listed below:
+
+        1. Creating project structures based on the user requirement using function `create_directory`.
+        2. Writing clean, efficient, and well-documented code using functions `create_file` and `write_file`.
+        3. Examining the project to reuse existing code snippets as much as possible; you may need to use
+           functions like `list_files`, `read_file` and `write_file`.
+        4. Writing the code into the file when creating new files; do not create empty files.
+        5. Using function `preview_csv_data` to preview the CSV data if the task includes CSV data processing.
+        6. Deciding whether the task requires execution and debugging before moving on to the next one.
+        7. Generating the commands to run and test the current task, and the dependencies list for this task.
+        8. You only write Python scripts; don't write Jupyter notebooks, which require interactive execution.
+        """
+        self.search_prompt = """
+        9. Performing web searches using function `web_search` to get up-to-date information or additional context.
+        """
+
+        self.functions = [
+            schema_read_file,
+            schema_create_file,
+            schema_write_file,
+            schema_list_files,
+            schema_create_directory,
+            schema_search_arxiv,
+            schema_search_papers_with_code,
+            schema_web_search,
+            schema_execute_command,
+            schema_preview_csv_data
+        ]
+
+        if config_data.get('search_key'):
+            self.functions.append(schema_web_search)
+            self.sys_prompt += self.search_prompt
+
+        if not self.cache.is_empty():
+            dataset = self.cache.resume_variable("dataset")
+            ml_requirement = self.cache.resume_variable("ml_requirement")
+            advisor_report = self.cache.resume_variable("advisor_report")
+            self.sys_prompt += f"""
+            The overall project information: \n
+            {'Dataset: ' + dataset if dataset else ''} \n
+            {'Requirement: ' + ml_requirement if ml_requirement else ''} \n
+            {'Advisor: ' + advisor_report if advisor_report else ''} \n
+            """
+
+        self.chat_history.append({"role": 'system', "content": self.sys_prompt})
+
+    def greet(self):
+        """
+        Generate a greeting message for the user, including an inquiry about the project's purpose and
+        an overview of the support provided. This sets a collaborative tone with the user.
+
+        Returns:
+            str: The generated greeting message.
+        """
+        system_prompt = """
+        You are a Chatbot designed to collaborate with users on planning and debugging ML projects.
+        Your goal is to provide a concise and friendly greeting within 50 words, including:
+        1. Inferring the project's purpose or objective.
+        2. Summarizing the previous conversation if one exists.
+        3. Offering a brief overview of the assistance and support you can provide to the user, such as:
+            - Helping with project planning and management.
+            - Assisting with debugging and troubleshooting code.
+            - Offering advice on best practices and optimization techniques.
+            - Providing resources and references for further learning.
+        Make sure your greeting is inviting and sets a positive tone for collaboration.
+        """
+        self.chat_history.append({"role": "system", "content": system_prompt})
+        greets = self.model.query(
+            self.chat_history,
+            function_call='auto',
+            functions=self.functions,
+        )
+
+        self.chat_history.append({"role": "assistant", "content": greets})
+        return greets
+
+    def chat(self, user_prompt):
+        """
+        Handle the streamed response from the model.
+        Streaming is handled by the model's streaming function, so we don't
+        need to put it into JSON mode.
+
+        Args:
+            user_prompt: the user prompt.
+        """
+        text = ''
+        self.chat_history.append({"role": "user", "content": user_prompt})
+        for content in self.model.stream(
+                self.chat_history,
+                function_call='auto',
+                functions=self.functions,
+        ):
+            if content:
+                text += content
+                yield text
+
+        self.chat_history.append({"role": "assistant", "content": text})
diff --git a/mle/agents/coder.py b/mle/agents/coder.py
index f86dd3f..090cfb7 100644
--- a/mle/agents/coder.py
+++ b/mle/agents/coder.py
@@ -205,35 +205,3 @@ def interact(self, task_dict: dict):
         )
         print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan")
         return self.code_summary
-
-    def chat(self, user_prompt):
-        """
-        Handle the response from the model streaming.
-        The stream mode is integrative with the model streaming function, we don't
-        need to set it into the JSON mode.
-        Args:
-            user_prompt: the user prompt.
-        """
-        text = ''
-        self.chat_history.append({"role": "user", "content": user_prompt})
-        for content in self.model.stream(
-                self.chat_history,
-                function_call='auto',
-                functions=[
-                    schema_read_file,
-                    schema_create_file,
-                    schema_write_file,
-                    schema_list_files,
-                    schema_create_directory,
-                    schema_search_arxiv,
-                    schema_search_papers_with_code,
-                    schema_web_search,
-                    schema_execute_command,
-                    schema_preview_csv_data
-                ]
-        ):
-            if content:
-                text += content
-                yield text
-
-        self.chat_history.append({"role": "assistant", "content": text})
diff --git a/mle/cli.py b/mle/cli.py
index 7f3bf3a..e134e9a 100644
--- a/mle/cli.py
+++ b/mle/cli.py
@@ -6,18 +6,13 @@
 import uvicorn
 import questionary
 from pathlib import Path
-from rich.live import Live
-from rich.panel import Panel
 from rich.console import Console
-from rich.markdown import Markdown
-from concurrent.futures import ThreadPoolExecutor
 
 import mle
+import mle.workflow as workflow
 from mle.server import app
 from mle.model import load_model
-from mle.agents import CodeAgent
-import mle.workflow as workflow
-from mle.utils import Memory, WorkflowCache
+from mle.utils import Memory
 from mle.utils.system import (
     get_config,
     write_config,
@@ -58,6 +53,9 @@ def start(ctx, mode, model):
     elif mode == 'kaggle':
         # Kaggle mode
         return ctx.invoke(kaggle, model=model)
+    elif mode == 'chat':
+        # Chat mode
+        return ctx.invoke(chat, model=model)
     else:
         raise ValueError("Invalid mode. Supported modes: 'baseline', 'report', 'kaggle'.")
 
@@ -79,6 +77,8 @@ def report(ctx, repo, model, user, visualize):
         "[blue underline]http://localhost:3000/[/blue underline]",
         console=console, title="MLE Report", color="green"
     )
+    from concurrent.futures import ThreadPoolExecutor
+
     with ThreadPoolExecutor() as executor:
         future1 = executor.submit(ctx.invoke, serve)
         future2 = executor.submit(ctx.invoke, web)
@@ -139,37 +139,15 @@
 
 
 @cli.command()
-def chat():
+@click.option('--model', default=None, help='The model to use for the chat.')
+def chat(model):
     """
    chat: start an interactive chat with LLM to work on your ML project.
     """
     if not check_config(console):
         return
 
-    model = load_model(os.getcwd())
-    cache = WorkflowCache(os.getcwd())
-    coder = CodeAgent(model)
-
-    # read the project information
-    dataset = cache.resume_variable("dataset")
-    ml_requirement = cache.resume_variable("ml_requirement")
-    advisor_report = cache.resume_variable("advisor_report")
-
-    # inject the project information into prompts
-    coder.read_requirement(advisor_report or ml_requirement or dataset)
-
-    while True:
-        try:
-            user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
-            if user_pmpt:
-                with Live(console=Console()) as live:
-                    for text in coder.chat(user_pmpt.strip()):
-                        live.update(
-                            Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"),
-                            refresh=True
-                        )
-        except (KeyboardInterrupt, EOFError):
-            exit()
+    return workflow.chat(os.getcwd(), model)
 
 
 @cli.command()
diff --git a/mle/utils/cache.py b/mle/utils/cache.py
index d9d737f..a8e72d2 100644
--- a/mle/utils/cache.py
+++ b/mle/utils/cache.py
@@ -71,16 +71,18 @@ class WorkflowCache:
     methods to load, store, and remove cached steps.
     """
 
-    def __init__(self, project_dir: str):
+    def __init__(self, project_dir: str, workflow: str = 'baseline'):
         """
         Initialize WorkflowCache with a project directory.
 
         Args:
             project_dir (str): The directory of the project.
+            workflow (str): The name of the cached workflow.
         """
         self.project_dir = project_dir
-        self.buffer = self._load_cache_buffer()
-        self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"]
+        self.workflow = workflow
+        self.buffer = self._load_cache_buffer(workflow)
+        self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"][workflow]
 
     def is_empty(self) -> bool:
         """
@@ -124,22 +126,27 @@ def resume_variable(self, key: str, step: Optional[int] = None):
         if step is not None:
             return self.__call__(step).resume(key)
         else:
-            for step in range(self.current_step()):
+            for step in range(self.current_step() + 1):
                 value = self.resume_variable(key, step)
                 if value is not None:
                     return value
             return None
 
-    def _load_cache_buffer(self) -> Dict[str, Any]:
+    def _load_cache_buffer(self, workflow: str) -> Dict[str, Any]:
         """
         Load the cache buffer from the configuration.
 
+        Args:
+            workflow (str): The name of the cached workflow.
+
         Returns:
             dict: The buffer loaded from the configuration.
         """
         buffer = get_config() or {}
-        if "cache" not in buffer:
+        if "cache" not in buffer.keys():
             buffer["cache"] = {}
+        if workflow not in buffer["cache"].keys():
+            buffer["cache"][workflow] = {}
         return buffer
 
     def _store_cache_buffer(self) -> None:
@@ -159,7 +166,7 @@ def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperat
         Returns:
             WorkflowCacheOperator: An instance of WorkflowCacheOperator.
         """
-        if step not in self.cache:
+        if step not in self.cache.keys():
             timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             self.cache[step] = {
                 "step": step,
diff --git a/mle/workflow/__init__.py b/mle/workflow/__init__.py
index 1c04456..36d25c5 100644
--- a/mle/workflow/__init__.py
+++ b/mle/workflow/__init__.py
@@ -1,3 +1,4 @@
 from .baseline import baseline
 from .report import report
-from .kaggle import kaggle
\ No newline at end of file
+from .kaggle import kaggle
+from .chat import chat
diff --git a/mle/workflow/baseline.py b/mle/workflow/baseline.py
index faa3d89..342e82f 100644
--- a/mle/workflow/baseline.py
+++ b/mle/workflow/baseline.py
@@ -29,7 +29,7 @@ def baseline(work_dir: str, model=None):
     """
     console = Console()
-    cache = WorkflowCache(work_dir)
+    cache = WorkflowCache(work_dir, 'baseline')
     model = load_model(work_dir, model)
 
     if not cache.is_empty():
diff --git a/mle/workflow/chat.py b/mle/workflow/chat.py
new file mode 100644
index 0000000..fb78043
--- /dev/null
+++ b/mle/workflow/chat.py
@@ -0,0 +1,41 @@
+"""
+Chat Mode: the mode to have an interactive chat with the LLM to work on your ML project.
+"""
+import os
+import questionary
+from rich.live import Live
+from rich.panel import Panel
+from rich.console import Console
+from rich.markdown import Markdown
+from mle.model import load_model
+from mle.utils import print_in_box, WorkflowCache
+from mle.agents import ChatAgent
+
+
+def chat(work_dir: str, model=None):
+    console = Console()
+    cache = WorkflowCache(work_dir, 'chat')
+    model = load_model(work_dir, model)
+    chatbot = ChatAgent(model)
+
+    if not cache.is_empty():
+        if questionary.confirm(f"Would you like to continue the previous conversation?\n").ask():
+            chatbot.chat_history = cache.resume_variable("conversation")
+
+    with cache(step=1, name="chat") as ca:
+        greets = chatbot.greet()
+        print_in_box(greets, console=console, title="MLE Chatbot", color="magenta")
+
+        while True:
+            try:
+                user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
+                if user_pmpt:
+                    with Live(console=Console()) as live:
+                        for text in chatbot.chat(user_pmpt.strip()):
+                            live.update(
+                                Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"),
+                                refresh=True
+                            )
+                    ca.store("conversation", chatbot.chat_history)
+            except (KeyboardInterrupt, EOFError):
+                break
diff --git a/mle/workflow/kaggle.py b/mle/workflow/kaggle.py
index e9e2d6e..3eb5395 100644
--- a/mle/workflow/kaggle.py
+++ b/mle/workflow/kaggle.py
@@ -15,7 +15,7 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None):
     The workflow of the kaggle mode.
     """
     console = Console()
-    cache = WorkflowCache(work_dir)
+    cache = WorkflowCache(work_dir, 'kaggle')
     model = load_model(work_dir, model)
     kaggle = KaggleIntegration(kaggle_username, kaggle_token)
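
Usage sketch (not part of the diff): a minimal example of how the pieces added above fit together, based only on the code in this change. The working directory, prompt text, and model fallback below are hypothetical placeholders.

    # Assumes an initialized MLE-Agent config so load_model() can resolve a model,
    # mirroring what mle/workflow/chat.py does without the interactive loop.
    from mle.model import load_model
    from mle.agents import ChatAgent

    work_dir = "."                          # hypothetical project directory
    model = load_model(work_dir, None)      # None defers to the project's configured model
    chatbot = ChatAgent(model)

    print(chatbot.greet())                  # one-shot greeting generated via model.query

    reply = ""
    for partial in chatbot.chat("How should I structure the baseline script?"):
        reply = partial                     # chat() yields the accumulated streamed text
    print(reply)

From the command line, the same flow is reached with `mle chat --model <name>` or via `mle start` in the chat mode, both of which now delegate to workflow.chat(os.getcwd(), model).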