diff --git a/gpt4all_llm.py b/gpt4all_llm.py
index 550b85a23..31058d735 100644
--- a/gpt4all_llm.py
+++ b/gpt4all_llm.py
@@ -73,32 +73,60 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         pass
 
 
-def get_llm_gpt4all(model_name, model=None,
+def get_model_kwargs(env_kwargs, default_kwargs, cls):
+    # default from class
+    model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
+    # from our defaults
+    model_kwargs.update(default_kwargs)
+    # from user defaults
+    model_kwargs.update(env_kwargs)
+    # ensure only valid keys
+    func_names = list(inspect.signature(cls).parameters)
+    model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
+    return model_kwargs
+
+
+def get_llm_gpt4all(model_name,
+                    model=None,
                     max_new_tokens=256,
                     temperature=0.1,
                     repetition_penalty=1.0,
                     top_k=40,
                     top_p=0.7):
     env_gpt4all_file = ".env_gpt4all"
-    model_kwargs = dotenv_values(env_gpt4all_file)
+    env_kwargs = dotenv_values(env_gpt4all_file)
     callbacks = [H2OStreamingStdOutCallbackHandler()]
-    n_ctx = model_kwargs.pop('n_ctx', 2048 - max_new_tokens)
-    default_params = {'context_erase': 0.5, 'n_batch': 1, 'n_ctx': n_ctx, 'n_predict': max_new_tokens,
-                      'repeat_last_n': 64 if repetition_penalty != 1.0 else 0, 'repeat_penalty': repetition_penalty,
-                      'temp': temperature, 'top_k': top_k, 'top_p': top_p, 'use_mlock': True}
+    n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
+    default_kwargs = dict(context_erase=0.5,
+                          n_batch=1,
+                          n_ctx=n_ctx,
+                          n_predict=max_new_tokens,
+                          repeat_last_n=64 if repetition_penalty != 1.0 else 0,
+                          repeat_penalty=repetition_penalty,
+                          temp=temperature,
+                          temperature=temperature,
+                          top_k=top_k,
+                          top_p=top_p,
+                          use_mlock=True,
+                          verbose=False)
     if model_name == 'llama':
-        model_path = model_kwargs.pop('model_path_llama') if model is None else model
-        llm = H2OLlamaCpp(model_path=model_path, n_ctx=n_ctx, callbacks=callbacks, verbose=False)
+        cls = H2OLlamaCpp
+        model_path = env_kwargs.pop('model_path_llama') if model is None else model
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
+        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
+        llm = cls(**model_kwargs)
     elif model_name == 'gpt4all_llama':
-        model_path = model_kwargs.pop('model_path_gpt4all_llama') if model is None else model
-        llm = H2OGPT4All(model=model_path, backend='llama', callbacks=callbacks,
-                         verbose=False, **default_params,
-                         )
+        cls = H2OGPT4All
+        model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
+        model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
+        llm = cls(**model_kwargs)
     elif model_name == 'gptj':
-        model_path = model_kwargs.pop('model_path_gptj') if model is None else model
-        llm = H2OGPT4All(model=model_path, backend='gptj', callbacks=callbacks,
-                         verbose=False, **default_params,
-                         )
+        cls = H2OGPT4All
+        model_path = env_kwargs.pop('model_path_gptj') if model is None else model
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
+        model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
+        llm = cls(**model_kwargs)
     else:
         raise RuntimeError("No such model_name %s" % model_name)
     return llm
diff --git a/prompter.py b/prompter.py
index cdcdbc5fa..67bbd8887 100644
--- a/prompter.py
+++ b/prompter.py
@@ -44,6 +44,7 @@ class PromptType(Enum):
         'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
         'gptj',  # internally handles prompting
         'llama',  # internally handles prompting
+        'gpt4all_llama',  # internally handles prompting
         ],
     'prompt_answer': [
         'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
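
Below is a standalone sketch (not part of the patch) of the kwargs-merging
pattern that `get_model_kwargs` introduces: constructor defaults are read via
introspection, overlaid first with our defaults and then with user `.env`
values, and finally filtered down to the parameters the constructor actually
accepts. `FakeLLM` is a hypothetical stand-in for H2OLlamaCpp/H2OGPT4All.

    import inspect

    class FakeLLM:
        # stand-in for the real LLM wrapper classes; every parameter has a
        # default, as with the pydantic-based LangChain wrappers
        def __init__(self, model_path=None, n_ctx=2048, temp=0.8, verbose=True):
            self.params = dict(model_path=model_path, n_ctx=n_ctx, temp=temp, verbose=verbose)

    def get_model_kwargs(env_kwargs, default_kwargs, cls):
        # start from the constructor's own defaults
        model_kwargs = {k: v.default for k, v in inspect.signature(cls).parameters.items()}
        model_kwargs.update(default_kwargs)  # our defaults override class defaults
        model_kwargs.update(env_kwargs)      # user .env values override everything
        # drop keys the constructor does not accept
        valid = set(inspect.signature(cls).parameters)
        return {k: v for k, v in model_kwargs.items() if k in valid}

    env_kwargs = dict(n_ctx=1024, unknown_key='dropped')
    default_kwargs = dict(temp=0.1, verbose=False)
    kwargs = get_model_kwargs(env_kwargs, default_kwargs, FakeLLM)
    assert kwargs == dict(model_path=None, n_ctx=1024, temp=0.1, verbose=False)

The final filtering step is presumably also why `default_kwargs` in the patch
can carry both `temp` and `temperature`: each wrapper class keeps only the
spelling its constructor recognizes.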