diff --git a/src/rawdog/config.py b/src/rawdog/config.py index dbbc4bb..301bb55 100644 --- a/src/rawdog/config.py +++ b/src/rawdog/config.py @@ -10,6 +10,7 @@ "llm_api_key": None, "llm_base_url": None, "llm_model": "gpt-4-turbo-preview", + "pip_model": None, "llm_custom_provider": None, "llm_temperature": 1.0, "retries": 2, @@ -21,6 +22,7 @@ setting_descriptions = { "retries": "If the script fails, retry this many times before giving up.", "leash": "Print the script before executing and prompt for confirmation.", + "pip_model": "The model to use to get package name from import name.", } diff --git a/src/rawdog/llm_client.py b/src/rawdog/llm_client.py index e6c3404..20cbc6b 100644 --- a/src/rawdog/llm_client.py +++ b/src/rawdog/llm_client.py @@ -46,7 +46,14 @@ def add_message(self, role: str, content: str): def get_python_package(self, import_name: str): base_url = self.config.get("llm_base_url") - model = self.config.get("llm_model") + model = self.config.get("pip_model") + llm_model = self.config.get("llm_model") + if model is None: + if "ft:" in llm_model or "rawdog" in llm_model or "abante" in llm_model: + model = "gpt-3.5-turbo" + else: + model = llm_model + custom_llm_provider = self.config.get("llm_custom_provider") messages = [