diff --git a/browserpilot/agents/compilers/instruction_compiler.py b/browserpilot/agents/compilers/instruction_compiler.py index a7c8dfb..bfa8d97 100644 --- a/browserpilot/agents/compilers/instruction_compiler.py +++ b/browserpilot/agents/compilers/instruction_compiler.py @@ -6,6 +6,7 @@ import logging import traceback import os +import re from typing import Dict, List, Union @@ -16,8 +17,8 @@ logger = logging.getLogger(__name__) # Instantiate OpenAI with OPENAI_API_KEY. -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY").strip(), + base_url=os.environ.get("OPENAI_API_BASE_URL", None)) """Set up all the prompt variables.""" # Designated tokens. @@ -300,6 +301,21 @@ def get_completion( stop=stop, ) text = response.choices[0].message.content + elif'api.openai.com' not in str(client.base_url): + # LLama3 is more aggressive and stops at the first stop token. might want + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + temperature=temperature + ) + raw_text = response.choices[0].message.content + text = "\n".join(re.findall(r"```([^`]+)```", raw_text)) + + else: response = client.completions.create( model=model, @@ -335,6 +351,7 @@ def get_action_output(self, instructions): """Get the action output for the given instructions.""" prompt = self.base_prompt.format(instructions=instructions) completion = self.get_completion(prompt).strip() + action_output = completion.strip() lines = [line for line in action_output.split("\n") if not line.startswith("import ")] action_output = "\n".join(lines)