From fbbae19d4e593193b954e1dc8de6305da63014d2 Mon Sep 17 00:00:00 2001 From: Ivan Kwiatkowski Date: Tue, 17 Sep 2024 21:19:44 +0200 Subject: [PATCH] Added support for local models using Ollama (#36) Major refactoring to support dynamic construction of the UI menus. This was necessary to support arbitrary model combinations installed via Ollama. Updated translations. --- gepetto/config.ini | 4 + gepetto/config.py | 19 +++-- gepetto/ida/handlers.py | 9 +- gepetto/ida/ui.py | 80 ++++++++---------- gepetto/locales/ca_ES/LC_MESSAGES/gepetto.mo | Bin 3999 -> 4220 bytes gepetto/locales/ca_ES/LC_MESSAGES/gepetto.po | 3 + gepetto/locales/es_ES/LC_MESSAGES/gepetto.mo | Bin 4054 -> 4281 bytes gepetto/locales/es_ES/LC_MESSAGES/gepetto.po | 3 + gepetto/locales/fr_FR/LC_MESSAGES/gepetto.mo | Bin 4308 -> 4512 bytes gepetto/locales/fr_FR/LC_MESSAGES/gepetto.po | 3 + gepetto/locales/gepetto.pot | 3 + gepetto/locales/it_IT/LC_MESSAGES/gepetto.mo | Bin 3993 -> 4214 bytes gepetto/locales/it_IT/LC_MESSAGES/gepetto.po | 5 +- gepetto/locales/ko_KR/LC_MESSAGES/gepetto.mo | Bin 756 -> 4501 bytes gepetto/locales/ko_KR/LC_MESSAGES/gepetto.po | 50 ++++++----- gepetto/locales/ru/LC_MESSAGES/gepetto.mo | Bin 5193 -> 5434 bytes gepetto/locales/ru/LC_MESSAGES/gepetto.po | 5 +- gepetto/locales/tr/LC_MESSAGES/gepetto.mo | Bin 4046 -> 4238 bytes gepetto/locales/tr/LC_MESSAGES/gepetto.po | 5 +- gepetto/locales/zh_CN/LC_MESSAGES/gepetto.mo | Bin 2062 -> 2239 bytes gepetto/locales/zh_CN/LC_MESSAGES/gepetto.po | 5 +- gepetto/locales/zh_TW/LC_MESSAGES/gepetto.mo | Bin 3776 -> 3953 bytes gepetto/locales/zh_TW/LC_MESSAGES/gepetto.po | 3 + .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 150 bytes .../models/__pycache__/base.cpython-312.pyc | Bin 0 -> 1118 bytes .../models/__pycache__/groq.cpython-312.pyc | Bin 0 -> 2268 bytes .../__pycache__/model_manager.cpython-312.pyc | Bin 0 -> 1895 bytes .../models/__pycache__/openai.cpython-312.pyc | Bin 0 -> 5333 bytes gepetto/models/base.py | 43 ++++------ gepetto/models/groq.py | 15 +++- gepetto/models/local_ollama.py | 64 ++++++++++++++ gepetto/models/model_manager.py | 39 +++++++++ gepetto/models/openai.py | 20 ++++- gepetto/models/together.py | 12 +++ requirements.txt | 1 + 35 files changed, 279 insertions(+), 112 deletions(-) create mode 100644 gepetto/models/__pycache__/__init__.cpython-312.pyc create mode 100644 gepetto/models/__pycache__/base.cpython-312.pyc create mode 100644 gepetto/models/__pycache__/groq.cpython-312.pyc create mode 100644 gepetto/models/__pycache__/model_manager.cpython-312.pyc create mode 100644 gepetto/models/__pycache__/openai.cpython-312.pyc create mode 100644 gepetto/models/local_ollama.py create mode 100644 gepetto/models/model_manager.py diff --git a/gepetto/config.ini b/gepetto/config.ini index f957e64..3a427c0 100644 --- a/gepetto/config.ini +++ b/gepetto/config.ini @@ -36,3 +36,7 @@ API_KEY = # Base URL if you want to redirect requests to a different / local model. # Can also be provided via the TOGETHER_BASE_URL environment variable. BASE_URL = + +[Ollama] +# Endpoint used to connect to the Ollama API. Default is http://localhost:11434 +HOST = \ No newline at end of file diff --git a/gepetto/config.py b/gepetto/config.py index ff3026b..918a1ef 100644 --- a/gepetto/config.py +++ b/gepetto/config.py @@ -2,7 +2,7 @@ import gettext import os -from gepetto.models.base import get_model +from gepetto.models.model_manager import instantiate_model, load_available_models model = None parsed_ini = None @@ -28,7 +28,8 @@ def load_config(): # Select model requested_model = parsed_ini.get('Gepetto', 'MODEL') - model = get_model(requested_model) + load_available_models() + model = instantiate_model(requested_model) def get_config(section, option, environment_variable=None, default=None): @@ -42,10 +43,16 @@ def get_config(section, option, environment_variable=None, default=None): :return: The value of the requested option. """ global parsed_ini - if parsed_ini and parsed_ini.get(section, option): - return parsed_ini.get(section, option) - if environment_variable and os.environ.get(environment_variable): - return os.environ.get(environment_variable) + try: + if parsed_ini and parsed_ini.get(section, option): + return parsed_ini.get(section, option) + if environment_variable and os.environ.get(environment_variable): + return os.environ.get(environment_variable) + except (configparser.NoSectionError, configparser.NoOptionError): + print(_("Warning: Gepetto's configuration doesn't contain option {option} in section {section}!").format( + option=option, + section=section + )) return default diff --git a/gepetto/ida/handlers.py b/gepetto/ida/handlers.py index 7e4210a..f2adb23 100644 --- a/gepetto/ida/handlers.py +++ b/gepetto/ida/handlers.py @@ -8,7 +8,7 @@ import idc import gepetto.config -from gepetto.models.base import get_model +from gepetto.models.model_manager import instantiate_model def comment_callback(address, view, response): @@ -67,14 +67,13 @@ def update(self, ctx): # ----------------------------------------------------------------------------- -def rename_callback(address, view, response, retries=0): +def rename_callback(address, view, response): """ Callback that extracts a JSON array of old names and new names from the response and sets them in the pseudocode. :param address: The address of the function to work on :param view: A handle to the decompiler window :param response: The response from the model - :param retries: The number of times that we received invalid JSON """ names = json.loads(response) @@ -148,13 +147,13 @@ def __init__(self, new_model, plugin): def activate(self, ctx): try: - gepetto.config.model = get_model(self.new_model) + gepetto.config.model = instantiate_model(self.new_model) except ValueError as e: # Raised if an API key is missing. In which case, don't switch. print(_("Couldn't change model to {model}: {error}").format(model=self.new_model, error=str(e))) return gepetto.config.update_config("Gepetto", "MODEL", self.new_model) # Refresh the menus to reflect which model is currently selected. - self.plugin.generate_plugin_select_menu() + self.plugin.generate_model_select_menu() def update(self, ctx): return idaapi.AST_ENABLE_ALWAYS diff --git a/gepetto/ida/ui.py b/gepetto/ida/ui.py index fa49323..d5376b2 100644 --- a/gepetto/ida/ui.py +++ b/gepetto/ida/ui.py @@ -1,12 +1,15 @@ +import functools import random import string +import threading import idaapi import ida_hexrays +import ida_kernwin import gepetto.config from gepetto.ida.handlers import ExplainHandler, RenameHandler, SwapModelHandler -from gepetto.models.base import GPT4_MODEL_NAME, GPT3_MODEL_NAME, GPT4o_MODEL_NAME, GROQ_MODEL_NAME, MISTRAL_MODEL_NAME +import gepetto.models.model_manager # ============================================================================= @@ -19,24 +22,12 @@ class GepettoPlugin(idaapi.plugin_t): explain_menu_path = "Edit/Gepetto/" + _("Explain function") rename_action_name = "gepetto:rename_function" rename_menu_path = "Edit/Gepetto/" + _("Rename variables") - - # Model selection menu - select_gpt35_action_name = "gepetto:select_gpt35" - select_gpt4_action_name = "gepetto:select_gpt4" - select_gpt4o_action_name = "gepetto:select_gpt4o" - select_groq_action_name = "gepetto:select_groq" - select_mistral_action_name = "gepetto:select_mistral" - select_gpt35_menu_path = "Edit/Gepetto/" + _("Select model") + f"/OpenAI/{GPT3_MODEL_NAME}" - select_gpt4_menu_path = "Edit/Gepetto/" + _("Select model") + f"/OpenAI/{GPT4_MODEL_NAME}" - select_gpt4o_menu_path = "Edit/Gepetto/" + _("Select model") + f"/OpenAI/{GPT4o_MODEL_NAME}" - select_groq_menu_path = "Edit/Gepetto/" + _("Select model") + f"/Groq/{GROQ_MODEL_NAME}" - select_mistral_menu_path = "Edit/Gepetto/" + _("Select model") + f"/Together/{GROQ_MODEL_NAME}" - wanted_name = 'Gepetto' wanted_hotkey = '' comment = _("Uses {model} to enrich the decompiler's output").format(model=str(gepetto.config.model)) help = _("See usage instructions on GitHub") menu = None + model_action_map = {} # ----------------------------------------------------------------------------- @@ -67,7 +58,7 @@ def init(self): idaapi.register_action(rename_action) idaapi.attach_action_to_menu(self.rename_menu_path, self.rename_action_name, idaapi.SETMENU_APP) - self.generate_plugin_select_menu() + self.generate_model_select_menu() # Register context menu actions self.menu = ContextMenuHooks() @@ -93,43 +84,42 @@ def bind_model_switch_action(self, menu_path, action_name, model_name): "", "", 208 if str(gepetto.config.model) == model_name else 0) # Icon #208 == check mark. - idaapi.register_action(action) - idaapi.attach_action_to_menu(menu_path, action_name, idaapi.SETMENU_APP) + ida_kernwin.execute_sync(functools.partial(idaapi.register_action, action), ida_kernwin.MFF_FAST) + ida_kernwin.execute_sync(functools.partial(idaapi.attach_action_to_menu, menu_path, action_name, idaapi.SETMENU_APP), + ida_kernwin.MFF_FAST) # ----------------------------------------------------------------------------- def detach_actions(self): - idaapi.detach_action_from_menu(self.select_gpt35_menu_path, self.select_gpt35_action_name) - idaapi.detach_action_from_menu(self.select_gpt4_menu_path, self.select_gpt4_action_name) - idaapi.detach_action_from_menu(self.select_gpt4o_menu_path, self.select_gpt4o_action_name) - idaapi.detach_action_from_menu(self.select_groq_menu_path, self.select_groq_action_name) - idaapi.detach_action_from_menu(self.select_mistral_menu_path, self.select_mistral_action_name) + for provider in gepetto.models.model_manager.list_models(): + for model in provider.supported_models(): + if model in self.model_action_map: + ida_kernwin.execute_sync(functools.partial(idaapi.unregister_action, self.model_action_map[model]), + ida_kernwin.MFF_FAST) + ida_kernwin.execute_sync(functools.partial(idaapi.detach_action_from_menu, + "Edit/Gepetto/" + _("Select model") + + f"/{provider.get_menu_name()}/{model}", + self.model_action_map[model]), + ida_kernwin.MFF_FAST) # ----------------------------------------------------------------------------- - def generate_plugin_select_menu(self): - # Delete any possible previous entries - idaapi.unregister_action(self.select_gpt35_action_name) - idaapi.unregister_action(self.select_gpt4_action_name) - idaapi.unregister_action(self.select_gpt4o_action_name) - idaapi.unregister_action(self.select_groq_action_name) - idaapi.unregister_action(self.select_mistral_action_name) - self.detach_actions() - - # For some reason, IDA seems to have a bug when replacing actions by new ones with identical names. - # The old action object appears to be reused, at least partially, leading to unwanted behavior? - # The best workaround I have found is to generate random names each time. - self.select_gpt35_action_name = f"gepetto:{''.join(random.choices(string.ascii_lowercase, k=7))}" - self.select_gpt4_action_name = f"gepetto:{''.join(random.choices(string.ascii_lowercase, k=7))}" - self.select_gpt4o_action_name = f"gepetto:{''.join(random.choices(string.ascii_lowercase, k=7))}" - self.select_groq_action_name = f"gepetto:{''.join(random.choices(string.ascii_lowercase, k=7))}" - self.select_mistral_action_name = f"gepetto:{''.join(random.choices(string.ascii_lowercase, k=7))}" - - self.bind_model_switch_action(self.select_gpt35_menu_path, self.select_gpt35_action_name, GPT3_MODEL_NAME) - self.bind_model_switch_action(self.select_gpt4_menu_path, self.select_gpt4_action_name, GPT4_MODEL_NAME) - self.bind_model_switch_action(self.select_gpt4o_menu_path, self.select_gpt4o_action_name, GPT4o_MODEL_NAME) - self.bind_model_switch_action(self.select_groq_menu_path, self.select_groq_action_name, GROQ_MODEL_NAME) - self.bind_model_switch_action(self.select_mistral_menu_path, self.select_mistral_action_name, MISTRAL_MODEL_NAME) + def generate_model_select_menu(self): + def do_generate_model_select_menu(): + # Delete any possible previous entries + self.detach_actions() + + for provider in gepetto.models.model_manager.list_models(): + for model in provider.supported_models(): + # For some reason, IDA seems to have a bug when replacing actions by new ones with identical names. + # The old action object appears to be reused, at least partially, leading to unwanted behavior? + # The best workaround I have found is to generate random names each time. + self.model_action_map[model] = f"gepetto:{model}_{''.join(random.choices(string.ascii_lowercase, k=7))}" + self.bind_model_switch_action("Edit/Gepetto/" + _("Select model") + f"/{provider.get_menu_name()}/{model}", + self.model_action_map[model], + model) + # Building the list of available models can take a few seconds with Ollama, don't hang the UI. + threading.Thread(target=do_generate_model_select_menu).start() # ----------------------------------------------------------------------------- diff --git a/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.mo b/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.mo index b72cd646f9407a9c5cc472d2e6987d44760e994f..038862a9d8c331ac098d66501618ed8ea0731625 100644 GIT binary patch delta 761 zcmZY6&1(}u7{~FSF)?YY_D!qSic_S-Q*AC)a!?D3RB{sAg3`jWS=+_180dL`Ne26!WNt3vNSMUptVODom2n#9IbUl`Y=z&@w)7nRyN#1b0qSSi3(*kq3ZemNobJphDh&!2#l}G` z;@;v;X$&#rwM2tdvhU8had5#yVOx5N!Ck}?YM_{sEH5c%uLS delta 540 zcmX}pJ4nM&6vpw-nwZ2_wLVfrY28HdfrH`$v8WWqMbSaTE-ns&Lf1kZEx3t`2rdrd z+QHGuMQ{;uaBy%CbaqiH_|Lt?11G;< zv4|shjl=j-+XtR4eiA1*U&R}WFXK1+Twn{IZj?6IUt%Br;12gI-6ZXCP{I!Uz_{XH zT*rK~G>Ye##s}=j_nLl-?R*edI4|P?K4T8^t<@gqW-t#($((wUgYVF0zwk2qx%B)Mb8=j{!}X2rEMg$E_TB KRqy02dG-g-VlX-Y diff --git a/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.po b/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.po index 794f60e..7606552 100644 --- a/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.po +++ b/gepetto/locales/ca_ES/LC_MESSAGES/gepetto.po @@ -122,3 +122,6 @@ msgstr "" msgid "Couldn't change model to {model}: {error}" msgstr "" + +msgid "Warning: Gepetto's configuration doesn't contain option {option} in section {section}!" +msgstr "Advertència: La configuració de Gepetto no conté l'opció {option} a la secció {section}!" \ No newline at end of file diff --git a/gepetto/locales/es_ES/LC_MESSAGES/gepetto.mo b/gepetto/locales/es_ES/LC_MESSAGES/gepetto.mo index 0f5e95281f0d3e7d011d70b1e4af8adacf7eeaf9..4070046b942dab6af75283cc63fee9bd6c4631b7 100644 GIT binary patch delta 767 zcmZY6%}Z2K7{~En(~L8{Ic2YeI7cCLrOqa0B!?i9B3g(T85Wo8(Jtn>GINVSL=cQt zZ6XS)bwSHnR3H-c4@isJge`@j{s32f&&+gN^uU?VbDw+XJm=+W>0|HOPFMDsh%=0C z#&^a!#;bk&5W^YiCzkOpcJG(2;s*Nfx7*LZ@HWq*S?MS~$8mg%S8x{xu$*hXpTV=z znife6FCO6}zQIelhqtiWA)UvUID}tt5Hopc8q4VGkMT5qz!BWW9X3%YNFBI)KyuW7 zACzw5>>;T{f4wKM6u#jSUN|hBz;}2ZKcQbx@d&4ZQ|KFQU=O~=KHS7R_yhg_s-4n9 ze1cWX7o`eL;25rBNPlgS3}GMVR>DcVj5YMPaEx)B!GVZd5C@RP$#%x3(qO}4v&~>n zV*O%GX`CYCw!{Ic?C+ia&0wMbh2N<^4b~#AP+g5&CYzhS5!77d7Aj`KF54uD`|Bo* z-F&ogzZN7>?95zj>#jfX)g*|ViI~CSr<0meC;qfV&f6}aI90>#AP%W zxxy*@K+~mcmsG+MmT?ge@DR;+&v=UYZmEh*oX1yOz;r$-q>2IOD`c{UmuR|9;Wp^B z=&A;ir5$@2dCdj|ief}(MKk`R$gGFtJoS^ND4Jk{8pZwWFLKeNC>wrjJmq9s&0NL3 IE*EZp0qx)`?EnA( diff --git a/gepetto/locales/es_ES/LC_MESSAGES/gepetto.po b/gepetto/locales/es_ES/LC_MESSAGES/gepetto.po index 08f1dfa..3c7ec07 100644 --- a/gepetto/locales/es_ES/LC_MESSAGES/gepetto.po +++ b/gepetto/locales/es_ES/LC_MESSAGES/gepetto.po @@ -123,3 +123,6 @@ msgstr "¡Edite este script para insertar su clave de API de {api_provider}!" msgid "Couldn't change model to {model}: {error}" msgstr "" + +msgid "Warning: Gepetto's configuration doesn't contain option {option} in section {section}!" +msgstr "Advertencia: La configuración de Gepetto no contiene la opción {option} en la sección {section}!" diff --git a/gepetto/locales/fr_FR/LC_MESSAGES/gepetto.mo b/gepetto/locales/fr_FR/LC_MESSAGES/gepetto.mo index a4dd57b23afa4828787a62483df412f9df9c9b70..b7c5f84614c9ba37d82b5541feec35aaa20020e7 100644 GIT binary patch delta 682 zcmZY5ze^lJ6u|Mf{zh{-P4FiX=M1tZ#E6Ir=P!hysErCSxd`gwZOp|6nBsi$prm#9usW3I#ih zMKtxLt3+gu+#!armg;NRi%WQpZGIdODaQpI#!Vc-YkY^{mm(GT5$iCHCT?LB9^)6h zz-bHz>1a2L4Hgkhqu;6~`qF}NjNvY}Vs2UX2RpEzI9irEFzSuU)P>O-@LR}>GL?d< zGvl>UebkgQoIn1_n%9n=ji>=TVSAoyCe@6qzdExE^OonhTE$#DsZGz%dX}S=o5-Y> zv#uyoNqZ)ndt9u9jus0~i(-AAXKQ~aqT<%mYwiCPm9{gNJDgPsE2-k<(;6{LleV50 KQlV6P>Fxt=8h?ub delta 492 zcmZY6KS%;$7{~G7sh9mV3o4DE&XNi&i6|lHk06mr5RFCA6oC-6MT<+cMAPLcv_uVU zLPJZsvVfkgL=Uy4dVp{@g8TeifMFgX$Z5({hUroy4g6#J}hA~zT*nk zFonrFX$_BX5~~=;uwNR*G@5!JyYLbdSi&co`9fi(LE0gGXq5K&zEVw`ij8YL#78{9 zbxtv1YV5#3iuzqDscXLsttTPZQ#yfJizI>({kn8es1zn0b!eGcb|(-1z=QOI zjz&o6U9YkP{RthVQwQ-X2trVF=+yh!*``BZc;@xYcZTQtJm2Tz;sT?S zvB$W|*gVFA=ub(9=-?c7Hc0pJ725xI>+eT6$@|T;)QanP8@KTu9$+sznfUt?yeO^c z1rx)E*Z3H>a1?*zQye}n-M|g(!>@P^Q;pIXI%xB6@jQOULEOcC4w21DKY2emAw9>v zlhQa2H%VR8*E*9jiBI?rdrwJcu!=+Y1@B;{nQq_%Jc~X(ckmg0!$&x9S_*L$ zuj3CaU{6cDj*Ao2*BTQCzheh>)4y9dg!UHN80`!iB5px6APKt>8yqSL4lE8^3(h3= zFZPtgIcD6JXpkIx@6@dY8~iV9OKmsoopOcBRWqq{W^CLo2YxVHFn7I@7e!&dVy41i z#-Dvzb|XIw%yj5gf_!ANk?RL0EY;GBvCD>3#hZ$A^~v&;YPPW{J25gp?^nWtdFcLI pbAdHIGwuJaX98;2wSA9#F&~!VKa%wWuUK>`QBNl&-sf80`~mhwcE$hz delta 540 zcmX}py)Q#i7{~FSy48E?rS#$=XeY6FYot;U3mqyku!zt~Ox>E4{hAuYCI)|jm$iwN z#by$dp@Y$6XlxSSd#-qr=X}n&=jNXCJoh_$oe5fL=SqZ6j+1X>hP>!tAl#Vri3Y3K z#tkfYM(^KZxPHb8_B$lja1~FmgpW9lCLZnQa7+rc#w5!|1-CK4b$rD=^t+@y9^o9` z<1D`atxY1jd;ynuzJ)C=e}-+=RU;i@IVtVn6Heerw^ZkR<(ZuG;0_0Iu}50L5*D$6 z^Z0^8_=A(^^-5Ei!+qSvO?p-<*j#@#bkF>0W-N F8ow{oEt>!U diff --git a/gepetto/locales/it_IT/LC_MESSAGES/gepetto.po b/gepetto/locales/it_IT/LC_MESSAGES/gepetto.po index ac5c54c..6ba6fac 100644 --- a/gepetto/locales/it_IT/LC_MESSAGES/gepetto.po +++ b/gepetto/locales/it_IT/LC_MESSAGES/gepetto.po @@ -119,4 +119,7 @@ msgid "Please edit the configuration file to insert your {api_provider} API key! msgstr "Per favore, modifica lo script insendo la tua {api_provider} API key!" msgid "Couldn't change model to {model}: {error}" -msgstr "" \ No newline at end of file +msgstr "" + +msgid "Warning: Gepetto's configuration doesn't contain option {option} in section {section}!" +msgstr "Avviso: La configurazione di Gepetto non contiene l'opzione {option} nella sezione {section}!" \ No newline at end of file diff --git a/gepetto/locales/ko_KR/LC_MESSAGES/gepetto.mo b/gepetto/locales/ko_KR/LC_MESSAGES/gepetto.mo index e1c44a021606bb3d4df38a818b8792c21fb90a28..6f6cfd88e0a845cc3b19a35b0ff4d1b8b7429a48 100644 GIT binary patch literal 4501 zcmb7GU2hy$89qRvu+Ws2Z-nTH1d6SA*LIqgXhVRIq-mO_bseXGT49ZMkJr=8&MaqU z5*y1>J~m++J4#yF&L-HYt8SVPxxyrNx+t!YxI^NCUjXV27dvMp$OV@b&wFNOKZ4c5 z$m3^s_MG?Qd7t;4^S8$znq#=Wjr%*e74GljUjG#S;TrffV-4U-z%k$tKf~B_z-vHS z|IhgQBM&k5OML%1up9V3@XNptfj;3SZKuK~&T8ju733j7lAU%)}& zN5H+nmp;eXE5Lc+)4=<{USRjbjAeko0KN>Y0O|SPf!_c=`gz8F0Neun9da55enihx zjJ*!r^hN9iR)OCI-Uq%4d<;Px2im|Lz<&TA0X~6$lusJ?EYJdO1+D?V3j7e52L2mJ zJbZ6cwC@F=gYT~bDX+f(4d4@xGWHYTN#H@?pMcK;H-VfY@D<=L;GcnAz=t1WED!7h zQi%ZUYq&uZLx~t@Xhf%d3FA5w{2JxzZTpZ6MWRMEa$js7x*4N>f1TbbnI+uGB0vYv1D384m-YA z^1Z3lVZTriuE$4&=LyM=8PYUHEWvG~DBK<{Mai1r$4zgH8~nM$FTTJHDGfXs6H@S_ zV#39P$fqOCf{7TSt=zElux9z9wW1`QlH-bev?9azI^1?VF2+liVcOU+;f*2Y9`4w1 zT9T&i>1?#$ys5J=F5J_|2Wi$Bby@MtSqRLopy6H_p{avZW-!mE@-P#Z$JSRk6{ zZ`(;2zw(8gXxSUEgy9M<@@7=)oMVrg1s_}zA)~Y#beJfG^hh0&PZ}k27@|97=7pT% zy9W-CVmGtc64j=&St6l1IiD((O|e0t6_PZCiw6h_-J*G?8c%T-ONsZxLU7+TKnX(d zq_5TLLdtx<={@U@V1Xr|Ax#q-f=D{NLGC)vJhme$(?#n=!s7UB+PaDvo+v-k6-1k? zn%cmm6rL|_aNBO-<5haZMmu=E4S`w{Jy9S@0m+Q#IDEu}$N>UPrAH2mM+2fM{4Fle z`BIYFiF3;=nx30sM^L?3CA2?jd)gK+N7uWWCB&R~eQ+DhQrBvVk~MJHsuA3{9e!+S z=A!6Y^A9_s3CMnCNK#xj=|u{@Ls>LFwBXUB)K5;N-N^{2s9r7#OCp?%KTK_ANum&V z+>UYgM#EA@LTafyS+7M8HQS~;h5^vD^T2bC!eeUYM1HE7dY9V`(GWZN3>%QnYltd+ zAfJ9&NSEe+k|FQU^rr>|^>g}=TQKwKpZWziJ>+EhYrgB5Im{=>czD?GhKD)&Ul&<+ zQs}-*?}bvbd|ThPC)0iS+s6B!%Jy&H+SlLLhlS}uam*y=8?4)& zeX@URAFkB%hU=w=q+z?3;W;wP4;;gQKbYy!cJZ@a@DWjZNEpQ|Kj^@>&AT*q^R6+^ zE4kU;UZYgXU`*StIV#38V7s>vP1jxz``q6Cr?>ZQf2#MWGkkEcw>x#{z@dFDYWg#M zDLsU3L^4rAN;Hpq^^0jI%{u7{Z`YBbXVN=b>&V%tkm-FGV|lHi9V4dKj5o`*uhf8r zA!VhXab(fWa=WCzy1V*!^62BPzMUx+OxM)n3|`ma+znN!s!JDnFmocDTjjysvMMjq zC-u&{D%aJO*^OW2!TDJpoT{qD6Fj(c0!Fx6Tn*kWtHm-`OABhQ63$f{cb0m1c&Zkh zo8{`a7lW&9 z9Z_7rNGA21Q>T{IQaQLj!`0n|VCh=tK88s79yMQA^Q$of6kxEtsxJM8;)g{%Aft(g zHaarA@qRsccbS}#p_szOR%#&N?yS1K(CKxkPQ&|_8 zx_2Wyb6%aRh3D$LQLj_+uB~&mT4~%_59cdQN1M62Tn^ru;o;e(@Z<_b)qV)jwW*=% ztGdRK?`CM7Sq0~@eKjnvXixL^bk%O%YxjnlR;s~;%W7erEH_q8sHM89)*AH{&{Nio zY9&w{!TDM^i)S|@O0voLqMdOFwCc`@#_bO_;;=OjqZ-E5X6ldYs4rL3wQ&A*;uCdV z4^~isYDHbHBPbrMEU43q?UA>PK{Hg2aApSEbQN`dNT!#TH##OzgAn&u&uFPuqqf_e zbg5Q!uv`b3dRc6(O_De~bo5}Hb2}fAI7qK(40ZPH;3A@FRaq-WpBxgwpo6a53~tX+ z0^wyqof6OkfkotkuBGraq(nxeqQp90T4>x^BeAuGPOK#Y*Zr#I&NW4*%MQ|#j+w4S zg1=g!5p!j>U}8+8@#nl2BRUWl`wxzoj$3i7s6BVAYy1PjKb#oUa!|3 zphr{r59&;{byP{zRbA7M<5LY7?}9$Md3a$N43IW4nbp;b9>HxoX`WD;I}%-^Y7ty=2Wxf{QUq~M(2W=Ag;zmz*-b&;XInnwJfHvhowM_@ed<~{lKmd6 zoe`dWw1c!QT7>q};>B~mM&t+f;%z*(L*x?9pg(`J+W(8!>0hW7X~7qG7CpR(zp)E@ z>nh`8xL+hI84AU~Gi<_T9KsLSk8SlL*D#8w(8Ck>3%jwsL1Y*&qn|He8-Byn_!CcZ z_{dI?0++hlDDsXxz-C!K@|of&!CmJ7; z9KcCDfFJQWZsC2*?dDg(KN!XNporst(F%#&z&rRDzhWOoTSX4z9Hw!tH9u>0SbOp% z>s+({f3a}+s)77#D`f309I6S|4Ufde?WCRPQ3LT*+;yGKv>J8nTgk-T@tB)*Y<1I# zr|nMHPu*D3R!*uMXDUu8qI7(;l2wa|aN$ftcT-l+nMqyLOJ+*HFIR8$68)0)bg@V2 zX{A^6sxFy_1O+q2lvm~vwfFxpO0SxndBBkW&gn%~Dw-TI3o_F@Romy#i^^X>FPHat PtBa~JvwfcYSZL!P$#S!o delta 476 zcmZY6!7GDt9LMp`Gn*%j%_hU9OwpnTE8${FO5}nYOO%>gaoEX0JK4qYX{p6=ZBZ_d z6LNP@E>1KTrO3%Y;QjQt`PTP&eZRloQ{Uh3d2U7?QpHBZUl!3rCdgNEh&(fR5EDM> z6Vtedi6&_okI{Yq>97CBE&3V16vk7`U>%q615=m|c>5dJCl!^a7+|A>L9F2#UgH9q z&C({W;1t$z9N%yZL#EefFh;+R(^$eu4tI(9KiA4`oBLp4tUM_crh=e}hyA`>`?ow$ql@h9HH=4z2K96_7k#76vrUHAiSF{v7n1b#jz zk|%$BUSyc_W%7c^EfVk09=wklTe*#`WCaIt4M*_^JFp|c8q|0LSMUyQAWLK)NALg# zu)S7f8bf@DUvc;@orbznrO$Dk@g}D55AuoRCbzH~_0$*MIxWuS!bhhkVgHxCoh`Mt z5IGIbm7VKl$;z>3I%_;L`#=r288;05wm{AJ=7l%AnAM@@8h1@25KciLG>?sQi>HRDfnIvYp-w>FDL<#Z;F=Cw)3QFcKGab(h3<^R2?ri<_A fi+c;jJ(G6T*%^Ub+;WW@uX`%zze*Om>)xCI3_*bU delta 475 zcmZY6&npCB9LMp`tX<>R7&%yJ4MQs>cK3pkQQGpOZIdIV$icQHCy~RP>``ixi<@NA z-Z;w5!GZh%;;59P0|$AJar4x7Uf=KYeVXTaW)9qk9{;&Lek4XGks_XmG;w0^U}R&` z8|H8WQ+3ij?qj(A@vnZz73%T0)P~3CVHM}_6$dewi26&|CHX3o^wOwc3!dX5USS^1 zdTAL8IEq!w;0unRZAP_+N$O1;!wS~8#0^?lZj|tfpi9t=3!!B@B-4MKw1Du z9{|!CK>8w-{|iV50r}j_3=A$nIu1yi0qMm+S|3Q?0@7hXnu~>jVFM#WG(#egaR4YF z#Kypo3#2atXF4< zJoOTgXh~vboBShyNO*D3&oic^yzVo*k{;^v>sqAXr7)-HQGu?udr!t+@zFI#(` u?_KwF&I*NUm|}&O%`2XEv_09e_Q{Ukr#-DtC#-(i(D`!VP6e19iVOhXm}Z9n delta 223 zcmXZUF$)247{~G7x&OI=a(__oRC1E#{(=gos273HqN#J@DP%Ip2AoyB6H59@;F+9}`Li zQAK2kT4LVdLut$t`H4-~kJnC!JjF58<9{5s|6&jAC;1|!ID-$+!!Z8AE7(+!>wkq8 zMTR6vMbYsYi*Ok`aUEN*yinvhzQ9}PVFT{s4ZLKEbYVN{{1lesH@uC%aD&U#p5zd; z_fLsLapyF1*k7KV5ve3GfsMF=`ossQ57f%19y{?Kx){V2EW_`p=kH(*n#H+!SFwZk z9c;!CY{Lz_j+FtCo4OLFqIFSkJcoLr-*^ryN_aMGC`peSmyPr3RpVi?{$Ki?^;HAu zuf|#9!cs6VRM6EO8MNY7-(B@MHV|_hyDp)kw$&T&dp#I&;7H%_A8`p&z0w-)VH%s5#TS#!g0AB)P{09l2^aC$ k@oub9OJ2*GPnz>F|J9|K|JS08bRZdOT-jmU`*x0g0bS=ac>n+a diff --git a/gepetto/locales/zh_TW/LC_MESSAGES/gepetto.po b/gepetto/locales/zh_TW/LC_MESSAGES/gepetto.po index d637e56..5e0decf 100644 --- a/gepetto/locales/zh_TW/LC_MESSAGES/gepetto.po +++ b/gepetto/locales/zh_TW/LC_MESSAGES/gepetto.po @@ -111,3 +111,6 @@ msgstr "請編輯程式碼以新增您的 {api_provider} API 金鑰!" msgid "Couldn't change model to {model}: {error}" msgstr "" + +msgid "Warning: Gepetto's configuration doesn't contain option {option} in section {section}!" +msgstr "警告: Gepetto 的配置在 {section} 部分中不包含選項 {option}!" \ No newline at end of file diff --git a/gepetto/models/__pycache__/__init__.cpython-312.pyc b/gepetto/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..985dfcabac10a1a6f227bc0a69ac7d1da9ec25cc GIT binary patch literal 150 zcmX@j%ge<81l%_lrhw?jAOanHW&w&!XQ*V*Wb|9fP{ah}eFmxdW$J7d6Iz^FR179# zlJiqiW86~JM1I3_+mGcU6wK3=b&@)w5cQlrz)B1L89GxH8-7C&j3jY9* zkPKbgC_jaY3bZsxROw6%O)7To{NN0sV2inLX1;G`c2>F%zFz`;Zoj=~M+U$*bQZ^% zgV`euCP0B25j4PreW2zTsKwyOT##4ab1TCQGcl_61v>R`l?GVWJ=M_(FlZPQ8YZ)j zyx(aVrY^O{_W3-~aF}}y&tL`S-POFdN$r(lsVx3O>1mDg7fmTKs|Dr1ApK5EI_z*j zS)^RfWX${Z&|aGs?>)od1QIwi>Ss{3l`UA*R?czEc%sa(r7DEduIvu_Bno<1Fc=Kc z6#}IX^tpOid)#~}7#FE+hMLvfXG13CpxK#!Q>kdSh+w1*F%?W($*G_391f1*td)R4Kqa#1YeY zBM6RT5@A{GS)>)JP0|vQlTh}Ud^MmPE2<44iR^}1NZCzL*;bzN9nICh3J@roT83eKfzl7xH~iPN$-$e0 JpP*q{{TCJq_xAt* literal 0 HcmV?d00001 diff --git a/gepetto/models/__pycache__/groq.cpython-312.pyc b/gepetto/models/__pycache__/groq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30dad3addc8e4600f2ecabfd4e92b01b8bf956d5 GIT binary patch literal 2268 zcmbVNO>7fK6rS;V*FW)JY~ql=X2VYtr8|wMI{$*p-EN5sc&|@A)!4nlHcF#y!XwU z_nv4+LR#C=71Qio+&ZBy!y>qn_ z84;|)G}_dq3d&))l|SRFpmCh>pF=5;dHbkwnMssbC0f`Pjr-vd*S$ch9J8mN^2C0R zsGr;6ho0~6IW%Y~~TnIv`VMHP${KmFNAk1IXiqw5boM4d_pR zb{khwU<;`0Vj10p+m!K(3Mk_;QuqmbSv;Ti8ntWM)oQlzcAaY~Rz0fV*5*}D*<-$G zMcwbLdRJ1tZp_hyYkR>}d{=b$U5=6Eb`Me|*ZL)XC5*ngT9&Q~uXNt81~%{LwHjRH z|JQS6PuW`*Xwus3z0qd(`xm?{+zq;SDnRett#kdudqmvtCeUtV@v^tq9z+F<2GNa` zD`;F8L|1SsV7BZ zj8Z{mK^w)u+0!2nGXJro=X#a%XHT-QZFN=f=z}=p8ZlD>^Ks`8i1{x9V5LNLR(Hv= zipsFAk%EYEUNHy~BL9PG^#HQ5NLwT$9Bj)e08F z3X^gM&1Ws<<&(f3B}`Iyv~ec&*zv%D6IWthlju3x1}{JxfiQ4mj8D*=Tn;FJFKe2L zVoIRPd`;PXOp4BM^jBIZNe`~}FfDYWJh0`(6=_1F6rLales{~Fr8=10kviuXqF zVZ5oDo=GpoI~U@evme|$@#E)<@&1bMK{N>mZgtG>P!{8#{E=vyIr!b!*9R99=?wvg z_pQcT=i3e~#t+Yj4zI=(UPXMEXs#ZaIkJ@KT1a%wUbvU} zseLhVrs7{sHdnP7ZFxt_azoow3B?;XMAX#&ZTOq;^3Jxu1y57tQLDUBCp88u-kWvL zHd;`u`7eY6^$$X^H39f@lQ;TT(VkuC{;ohzTDaehfu;idx%hD9&!zAISJJs0Qk2W3 ztX|s*nHJRJaGqF~43!C{MJXS2l(&JMM;5N-d6?+BqNFf=1GY>)ecyMo%_oKZko=%| z2*|T(v@F)Id9awga{$1K$#nrqaoeyaf%Fu%IyP+`|42zOxQ2i@!K lVCBFD0=9MxBkaEx`69A}S{G3366#z)owLV(L;LM4{0n<(48H&X literal 0 HcmV?d00001 diff --git a/gepetto/models/__pycache__/model_manager.cpython-312.pyc b/gepetto/models/__pycache__/model_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9735fcd78e88ec88329569c8211b20247ff8b73 GIT binary patch literal 1895 zcmZ`(-ESL35TCWrU+36$8mDPc+gv{s#cdKO0*a!56jDA)(^eu)qgpA)#kY3P*mu_M zxoxaeiclXCh!j^vN+?A{;)Rm`0(d~;CB;FwM_&-|(6^9Os?;ZD&pzC=Vl2<>?0oF( zZ)fi3SWE__*mn=Jk0gLUh2l}9x!Ar<0G43_bkIow(hW({1a)7*SM;a-1bjfHwPrv% zKqcz$MvVq_|1e5jN{93S^=Ey0@P?EQ>!BNvj_6_JG9~#v&A_KbUW7kGG`5D72ujEC zJjgec9sGZ?_3UtR6&lwgYre+5m-i%GI)&L2G(lq@_W|&Bn#do1K312#2u2oPz*=_G zgdX*>y$b0L+)LuWlXm0>o{#sC!l<@2|IQi+qAY%1HUf*MC72PDbvK! zJsIB}L=D^1$da_eG3LSAgM8OV5WqOH0mzhZcbuxJ%BaBu6$l zFHdI*8s|(L&O}DXujzV?Qu zqGlohY_K%r#2L*Q+@?%*O)QETMN{)$WOMK!+F7M{E^S1Q)x+_6B>wr)dy##({42bB zDZUV|mVXWJ-$qZ!2D3};8CtAQV8bxrn3wPS_ z&%?;vD%Pb=*LlaQh+OVDw{h|GIoP8I0Iri+m?e4P@eFyrbp;_@+~Y&97XBIRo^g=C z76`kfqxDOiV>{^WpyvN7(0Vc3`wl?M$Fq(@N~(M$D@dnIZfmA(Xg1}FrnrbJ(;BC` zVws9zbH&t()O8&yX-q3NT!xhMmub`6wv;k&x-e>&nK|V8SL8{JnTDAgQj|)sqFZ=^ zY1s-3WBb^&BhhO#lj?G$T)~=ld=_^)n)5K&n1?|D zzB^lvpHsU5#V%FVOvFyJ4a?-J$~weN#9+96>E6kaN}7Gy3F!tCPv(SGHER^8sye}} zRnRGOq`YOA4#_%UYU=!|f!>l;qNd}^>Z}WsIiZYIEMbr%F{o`i?7#BzbY z$?zMCpEAVoCuEy`LdXM%{tl79Vc?I0N51KtljldPv3qR?YS3G6@1BcqNyh2_$))av?&{0SCvKit9=v*vid1(tI(myv+4~6I7`{eox`C9%$UVUR@@9}y#aX%89*Q&|S_HP0S^sgMcA8uQi ztbVXOd2{mC#Z|KUS}k^TBYfIl_UQGxPYLb literal 0 HcmV?d00001 diff --git a/gepetto/models/__pycache__/openai.cpython-312.pyc b/gepetto/models/__pycache__/openai.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4efa0aab0353bd0cca5abcbf654cda1b3bccb68e GIT binary patch literal 5333 zcmdrQU2ogg^->fm$&xLFj%CM57RIrgYKh%8PO{Ex1Nbw^T02?lw25Yvp~Y*7GDWJF zRBS2skPX4$0sE+d0^WcE8-`$Lkv!xl3`N(My+~6AxlDissQ1t}CvkzHec8FBC^?=B zY~7xA0NpR1bM86k-t%$w=hoI11n`$Xd_LD1Md&j+@qnujo@QaNh7_bQ8cMo7lVqH^ zFX?l7HtBbIF3CAPpX8zUX@OiY8Dw0VP_ji~lVQc5j3`{P)rZa^g};TA01n<^-Mx}+ zMJ5saoKB_^jNSUWtfvcd8c!Mu)=EM;Z}krjoa(m&csN$ z%-pul89kkX`5MBAN}lv#R$;Jz&Zqcp`R~;6lN_wFw@_owEB;$38BjQAgGxvV++va; z=!4L=6!}Dp&7QqDZSw?!MEWz^QkHbntPh|65Rnl51V8lU;ONlI zWfK$AX=hTP>C9Q2$ChQxr0ac-s&CF1KrO2S`J&yL#+H=BdO_0V94_qvDF&Wd5`eWy z5AY^>6zwGRke?Bcisp#d3xhReLFCb3 za&)vD9CdYRNC_AL_&K`#o}s!dF-u5-B^R3dWE=W;-w?`D+8v-Uicev^8@cRT@|iDK z^t>Ffs1gn+{yVTAg>0+aZ*LXmF6VvsdInkGV=gGXxB3(28_5{BQ@6n~`;F+~J~wpq zTYEgYvh?QcfF<9Oe~Be=tHb+{rN^@ame{*3o=!UW*xj(lAI>A$0XQqU(wtrZ{I5bj6O;CAoO% z!j-FbaAf$6aq04<*X?M1-CF>@0oZL(#U^+~fNlZX*}+*5TPhIE4reSYFQqgU>sF%G z=A5f;2kH)O^R7mA3w7H1n9CDYw`@kT`8k8+WXo=OQ`QQ2oDhT9yzBg~<#>s*U$psA zX9Hm8Zp7Js6KiwiTVP(YpV}Ch8jD480LD%@5GlBh7?dPcS1m~@MW_RMvu=t?Bj8Q6 z-qF3{|9Q)Xuy1W(b)Y63s|d$FICoDi5C5PlTv`cy8jU}U?0N5Sd9PFzuKlfj-|Fe# zExvoY+Wz7e%e0=@5Dt_NzETxlEk|D6h;^>Dt+v%-iApT-!SnZ~s)&hOU(f7g;IWrCmmLwhhZJ{MoQ2NZC=kH{lu=AArB*)@Ti;p zZo4eB z-zi{Hq!~pP!O#04dGx1@TpqH`EMyudbc?dCh$&gqC_e3H<pa?ee2BafV zh40s8Rg-~opt0|WW~Am#v7lS3_AmGZr?7e*D-GQYP8ahkHDB>KWL(ma-A%^F8{$cB z27L8Kk5$a$o*}U(YZ|)ba_L!a$~tBHVPBH)*?wYZ*bYE4wP?=uIWgT1=CEnP1HvTK z#UX76Jqwg3a`K{@E97W(Vc|tf)UckmGV};%lw)9#d90hq#f-6luVX^X4DTF3&n8Cc zhnIDb)+*?-g{fXv1~@dV4`a(P#90-r0Ud^mBx}VIHI)jRHIDJ9DW(d9fap$1qNzF6 zG6za$&SD)CSrhSM3On`}v7Ulj1x)Uko7O_4pzEHwox40F-oUhKSO$eTXH>MHLSFB0t8jzdn|>7jzs*iL#WLD;=gE^=txvtA2 z>}2~=vv$|lU37AiGGegJNbl{dwZ7ApzSH-}AJ5gsrYmF9 z<;!nY#@?!pNtH3FHkPi8rSJD;9`?Oh>l>={4SfOC*(?9ph<4PXeU)h66U2w|%tlOD z3$2FA-KqOHQ&zK8JYS7zE5n~gV-I5;@5$w^LN$KX zS~Z<3$FEi6dM##DVg}S}dls3mws;ha*P@3j(Zf#=6Ec{3!8l$OCThZDMVKs4U#U&M zU73FS{+wFQ&R6HOvM^Z{aSP%PgPaGy6zx?v}DfUmNA~61h z>7N*6KN;j;{7n4@4BYyS7Ip3mUS9i8vWe$$|5fVvw~|49r$#uf6FI4%(2Pa@-@H7JTADb1jNE@PoaaK=hNlz# z-$jz_7Kw&{f=0&?NqVOsYxOS?(}EgX9bCara3$1f|Hhm1&_3J_K39joYi+I6V*yLgK7(R_Ljm0tM2`_Lh zPo8JE?x*}APGhzvm;e`X_n_<(QSv>y%P=%{$jU$zS0Inrob$RQwD=|8rHvDU+v|_l z0Vk1|1N5y+SW4RFeQ{s1gjV2${?kF|`+-meTy>sR@yp;VuIm{s*Yzk?`Ej=(=zSn)dx=6x-sz&oGgd{x1;i9o_y38yO%Q literal 0 HcmV?d00001 diff --git a/gepetto/models/base.py b/gepetto/models/base.py index d771e32..cafbb74 100644 --- a/gepetto/models/base.py +++ b/gepetto/models/base.py @@ -1,34 +1,23 @@ import abc -GPT3_MODEL_NAME = "gpt-3.5-turbo-0125" -GPT4_MODEL_NAME = "gpt-4-turbo" -GPT4o_MODEL_NAME = "gpt-4o" -GROQ_MODEL_NAME = "llama-3.1-70b-versatile" -MISTRAL_MODEL_NAME = "mistralai/Mixtral-8x22B-Instruct-v0.1" - -class LanguageModel(abc.ABC): +class LanguageModel(metaclass=abc.ABCMeta): @abc.abstractmethod - def query_model_async(self, query, cb): + def query_model_async(self, query, cb, additional_model_options) -> None: pass + def __eq__(self, other): + return self.get_menu_name() == other.get_menu_name() + + def __hash__(self): + return self.get_menu_name().__hash__() -def get_model(model): - """ - Instantiates a model based on its name - :param model: The model to use - :return: - """ - if model == GPT3_MODEL_NAME or model == GPT4_MODEL_NAME or model == GPT4o_MODEL_NAME: - from gepetto.models.openai import GPT - return GPT(model) - elif model == GROQ_MODEL_NAME: - from gepetto.models.groq import Groq - return Groq(model) - elif model == MISTRAL_MODEL_NAME: - from gepetto.models.together import Together - return Together(model) - else: - print(f"Warning: {model} does not exist! Using default model ({GPT4o_MODEL_NAME}).") - from gepetto.models.openai import GPT - return GPT(GPT4o_MODEL_NAME) + @staticmethod + @abc.abstractmethod + def supported_models() -> list[str]: + pass + + @staticmethod + @abc.abstractmethod + def get_menu_name() -> str: + pass diff --git a/gepetto/models/groq.py b/gepetto/models/groq.py index 2570dee..2b83483 100644 --- a/gepetto/models/groq.py +++ b/gepetto/models/groq.py @@ -2,10 +2,21 @@ import httpx as _httpx import gepetto.config +import gepetto.models.model_manager from gepetto.models.openai import GPT +GROQ_MODEL_NAME = "llama-3.1-70b-versatile" + class Groq(GPT): + @staticmethod + def get_menu_name() -> str: + return "Groq" + + @staticmethod + def supported_models(): + return [GROQ_MODEL_NAME] + def __init__(self, model): try: super().__init__(model) @@ -28,4 +39,6 @@ def __init__(self, model): http_client=_httpx.Client( proxies=proxy, ) if proxy else None - ) \ No newline at end of file + ) + +gepetto.models.model_manager.register_model(Groq) diff --git a/gepetto/models/local_ollama.py b/gepetto/models/local_ollama.py new file mode 100644 index 0000000..dbc1a65 --- /dev/null +++ b/gepetto/models/local_ollama.py @@ -0,0 +1,64 @@ +import functools +import threading + +import httpx as _httpx +import ida_kernwin +import ollama + +from gepetto.models.base import LanguageModel +import gepetto.models.model_manager +import gepetto.config + +OLLAMA_MODELS = None + +def create_client(): + host = gepetto.config.get_config("Ollama", "HOST", default="http://localhost:11434") + return ollama.Client(host=host) + +class Ollama(LanguageModel): + @staticmethod + def get_menu_name() -> str: + return "Ollama" + + @staticmethod + def supported_models(): + global OLLAMA_MODELS + if OLLAMA_MODELS is None: + try: + OLLAMA_MODELS = [m["name"] for m in create_client().list()["models"]] + except _httpx.ConnectError: + OLLAMA_MODELS = [] + return OLLAMA_MODELS + + def __str__(self): + return self.model + + def __init__(self, model): + self.model = model + self.client = create_client() + + def query_model_async(self, query, cb, additional_model_options = None): + if additional_model_options is None: + additional_model_options = {} + print(_("Request to {model} sent...").format(model=self.model)) + t = threading.Thread(target=self.query_model, args=[query, cb, additional_model_options]) + t.start() + + def query_model(self, query, cb, additional_model_options=None): + # Convert the OpenAI json parameter for Ollama + kwargs = {} + if "response_format" in additional_model_options and additional_model_options["response_format"]["type"] == "json_object": + kwargs["format"] = "json" + + try: + stream = self.client.generate(model=self.model, + prompt=query, + stream=False, + **kwargs) + ida_kernwin.execute_sync(functools.partial(cb, response=stream["response"]), + ida_kernwin.MFF_WRITE) + except Exception as e: + print(e) + + +gepetto.models.model_manager.register_model(Ollama) diff --git a/gepetto/models/model_manager.py b/gepetto/models/model_manager.py new file mode 100644 index 0000000..25531c6 --- /dev/null +++ b/gepetto/models/model_manager.py @@ -0,0 +1,39 @@ +import importlib.util +import os +import pathlib + +from gepetto.models.base import LanguageModel + +MODEL_LIST: list[LanguageModel] = list() +FALLBACK_MODEL = "gpt-4o" + +def register_model(model: LanguageModel): + if not issubclass(model, LanguageModel): + return + if any(existing.get_menu_name() == model.get_menu_name() for existing in MODEL_LIST): + return + MODEL_LIST.append(model) + +def list_models(): + return MODEL_LIST + +def instantiate_model(model): + """ + Instantiates a model based on its name + :param model: The model to use + :return: + """ + for m in MODEL_LIST: + if model in m.supported_models(): + return m(model) + # If nothing was found, use the default model. + print(f"Warning: {model} does not exist! Using default model ({FALLBACK_MODEL}).") + return instantiate_model(FALLBACK_MODEL) + +def load_available_models(): + folder = pathlib.Path(os.path.dirname(__file__)) + for py_file in folder.glob("*.py"): + module_name = py_file.stem # Get the file name without extension + spec = importlib.util.spec_from_file_location(module_name, py_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) diff --git a/gepetto/models/openai.py b/gepetto/models/openai.py index cdbbade..9cf76d0 100644 --- a/gepetto/models/openai.py +++ b/gepetto/models/openai.py @@ -7,10 +7,23 @@ import openai from gepetto.models.base import LanguageModel +import gepetto.models.model_manager import gepetto.config +GPT3_MODEL_NAME = "gpt-3.5-turbo-0125" +GPT4_MODEL_NAME = "gpt-4-turbo" +GPT4o_MODEL_NAME = "gpt-4o" + class GPT(LanguageModel): + @staticmethod + def get_menu_name() -> str: + return "OpenAI" + + @staticmethod + def supported_models(): + return [GPT3_MODEL_NAME, GPT4_MODEL_NAME, GPT4o_MODEL_NAME] + def __init__(self, model): self.model = model # Get API key @@ -79,7 +92,8 @@ def query_model_async(self, query, cb, additional_model_options=None): """ if additional_model_options is None: additional_model_options = {} - print(_("Request to {model} sent...").format(model=str(gepetto.config.model))) - t = threading.Thread(target=self.query_model, args=[query, cb, additional_model_options]) - t.start() + print(_("Request to {model} sent...").format(model=str(gepetto.config.model))) + t = threading.Thread(target=self.query_model, args=[query, cb, additional_model_options]) + t.start() +gepetto.models.model_manager.register_model(GPT) diff --git a/gepetto/models/together.py b/gepetto/models/together.py index 99da39d..1aa37a9 100644 --- a/gepetto/models/together.py +++ b/gepetto/models/together.py @@ -1,10 +1,20 @@ import together import gepetto.config +import gepetto.models.model_manager from gepetto.models.openai import GPT +MISTRAL_MODEL_NAME = "mistralai/Mixtral-8x22B-Instruct-v0.1" class Together(GPT): + @staticmethod + def get_menu_name() -> str: + return "Together" + + @staticmethod + def supported_models(): + return [MISTRAL_MODEL_NAME] + def __init__(self, model): try: super().__init__(model) @@ -23,3 +33,5 @@ def __init__(self, model): self.client = together.Together( api_key=api_key, base_url=base_url) + +gepetto.models.model_manager.register_model(Together) diff --git a/requirements.txt b/requirements.txt index 2366069..3983516 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ openai >= 1.0.0 groq >= 0.8.0 together >= 1.2.0 +ollama >= 0.3.3