-
Notifications
You must be signed in to change notification settings - Fork 274
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from iryna-kondr/feature_gpt4all
Added initial gpt4all support
- Loading branch information
Showing
8 changed files
with
135 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from skllm.gpt4all_client import get_chat_completion as _g4a_get_chat_completion | ||
from skllm.openai.chatgpt import get_chat_completion as _oai_get_chat_completion | ||
|
||
|
||
def get_chat_completion( | ||
messages, openai_key=None, openai_org=None, model="gpt-3.5-turbo", max_retries=3 | ||
): | ||
if model.startswith("gpt4all::"): | ||
return _g4a_get_chat_completion(messages, model[9:]) | ||
else: | ||
return _oai_get_chat_completion( | ||
messages, openai_key, openai_org, model, max_retries | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
try: | ||
from gpt4all import GPT4All | ||
except (ImportError, ModuleNotFoundError): | ||
GPT4All = None | ||
|
||
_loaded_models = {} | ||
|
||
|
||
def get_chat_completion(messages, model="ggml-gpt4all-j-v1.3-groovy"): | ||
if GPT4All is None: | ||
raise ImportError( | ||
"gpt4all is not installed, try `pip install scikit-llm[gpt4all]`" | ||
) | ||
if model not in _loaded_models.keys(): | ||
_loaded_models[model] = GPT4All(model) | ||
return _loaded_models[model].chat_completion( | ||
messages, verbose=False, streaming=False, temp=1e-10 | ||
) | ||
|
||
|
||
def unload_models(): | ||
global _loaded_models | ||
_loaded_models = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,45 @@ | ||
import openai | ||
from time import sleep | ||
import json | ||
from time import sleep | ||
|
||
import openai | ||
|
||
from skllm.openai.credentials import set_credentials | ||
from skllm.utils import find_json_in_string | ||
|
||
|
||
def construct_message(role, content): | ||
if role not in ("system", "user", "assistant"): | ||
raise ValueError("Invalid role") | ||
return {"role": role, "content": content} | ||
|
||
def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries = 3): | ||
|
||
def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3): | ||
set_credentials(key, org) | ||
error_msg = None | ||
error_type = None | ||
for _ in range(max_retries): | ||
try: | ||
completion = openai.ChatCompletion.create( | ||
model=model, temperature=0., messages=messages | ||
model=model, temperature=0.0, messages=messages | ||
) | ||
return completion | ||
except Exception as e: | ||
error_msg = str(e) | ||
error_type = type(e).__name__ | ||
error_type = type(e).__name__ | ||
sleep(3) | ||
print(f"Could not obtain the completion after {max_retries} retries: `{error_type} :: {error_msg}`") | ||
print( | ||
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::" | ||
f" {error_msg}`" | ||
) | ||
|
||
|
||
def extract_json_key(json_, key): | ||
try: | ||
as_json = json.loads(json_.replace('\n', '')) | ||
json_ = json_.replace("\n", "") | ||
json_ = find_json_in_string(json_) | ||
as_json = json.loads(json_) | ||
if key not in as_json.keys(): | ||
raise KeyError("The required key was not found") | ||
return as_json[key] | ||
except Exception as e: | ||
except Exception: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,22 @@ | ||
import numpy as np | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def to_numpy(X): | ||
if isinstance(X, pd.Series): | ||
X = X.to_numpy().astype(object) | ||
elif isinstance(X, list): | ||
X = np.asarray(X, dtype = object) | ||
X = np.asarray(X, dtype=object) | ||
if isinstance(X, np.ndarray) and len(X.shape) > 1: | ||
X = np.squeeze(X) | ||
return X | ||
return X | ||
|
||
|
||
def find_json_in_string(string): | ||
start = string.find("{") | ||
end = string.rfind("}") | ||
if start != -1 and end != -1: | ||
json_string = string[start : end + 1] | ||
else: | ||
json_string = {} | ||
return json_string |