Merge pull request #21 from iryna-kondr/feature_gpt4all
Added initial gpt4all support
iryna-kondr authored May 30, 2023
2 parents 3750b07 + d43b264 commit 610bbc3
Showing 8 changed files with 135 additions and 42 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -16,14 +16,15 @@ There are several ways you can contribute to this project:
**Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix.

> ### Legal Notice <!-- omit in toc -->
>
> When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license.
## Development dependencies

In order to install all development dependencies, run the following command:

```shell
pip install -e ".[dev]"
pip install -r requirements-dev.txt
```

To ensure that you follow the development workflow, please setup the pre-commit hooks:
59 changes: 43 additions & 16 deletions README.md
@@ -18,14 +18,21 @@ You can support the project in the following ways:

- ⭐ Star Scikit-LLM on GitHub (click the star button in the top right corner)
- 🐦 Check out our related project - [Falcon AutoML](https://github.com/OKUA1/falcon)
- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section
- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section or [Discord](https://discord.gg/NTaRnRpf)
- 🔗 Post about Scikit-LLM on LinkedIn or other platforms

## Documentation 📚

### Configuring OpenAI API Key

At the moment Scikit-LLM is only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required.
At the moment the majority of the Scikit-LLM estimators are only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required.

```python
from skllm.config import SKLLMConfig

SKLLMConfig.set_openai_key("<YOUR_KEY>")
SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
```

```python
from skllm.config import SKLLMConfig
@@ -39,6 +39,40 @@ SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
- If you have a free trial OpenAI account, the [rate limits](https://platform.openai.com/docs/guides/rate-limits/overview) are not sufficient (specifically 3 requests per minute). Please switch to the "pay as you go" plan first.
- When calling `SKLLMConfig.set_openai_org`, you have to provide your organization ID and **NOT** the name. You can find your ID [here](https://platform.openai.com/account/org-settings).

### Using GPT4ALL

In addition to OpenAI, some of the estimators can use [gpt4all](https://gpt4all.io/index.html) as a backend.

**This feature is considered highly experimental!**

In order to use gpt4all, you need to install the corresponding optional extra:

```bash
pip install scikit-llm[gpt4all]
```

In order to switch from OpenAI to a GPT4ALL model, simply provide a string in the format `gpt4all::<model_name>` as an argument. While the model runs completely locally, the estimator still treats it as an OpenAI endpoint and will try to check that an API key is present. Therefore, you can provide any string as a key.

```python
SKLLMConfig.set_openai_key("any string")
SKLLMConfig.set_openai_org("any string")

ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy")
```

When run for the first time, the model file will be downloaded automatically.

At the moment only the following estimators support gpt4all as a backend:
- `ZeroShotGPTClassifier`
- `MultiLabelZeroShotGPTClassifier`
- `FewShotGPTClassifier`

When using gpt4all, please keep the following in mind:

1. Not all gpt4all models are commercially licensable; please consult the gpt4all website for more details.
2. The accuracy of these models may be considerably lower than that of the models provided by OpenAI (especially gpt-4).
3. Not all of the available models have been tested; some may not work with scikit-llm at all.

### Zero-Shot Text Classification

One of the powerful ChatGPT features is the ability to perform text classification without being re-trained. For that, the only requirement is that the labels must be descriptive.
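
The rest of this section is collapsed in the diff view. As a rough orientation only, a minimal zero-shot sketch might look like the following; the import path, the example texts, and the labels are assumptions for illustration and are not part of this commit.

```python
# Illustrative sketch only (not part of this commit); assumes the OpenAI key/org
# have been configured as shown above and that ZeroShotGPTClassifier is
# re-exported from the top-level skllm package.
from skllm import ZeroShotGPTClassifier

X = [
    "The new stadium opened just in time for the season.",
    "The central bank raised interest rates again.",
]
y = ["sports", "economy"]  # descriptive labels are the only requirement

clf = ZeroShotGPTClassifier(openai_model="gpt-3.5-turbo")
clf.fit(X, y)
predictions = clf.predict(["Shares rallied after the earnings report."])
```
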
@@ -222,17 +263,3 @@ translated_text = t.fit_transform(X)
- [ ] Open source models

*The order of the elements in the roadmap is arbitrary and does not reflect the planned order of implementation.*

## Contributing

In order to install all development dependencies, run the following command:

```shell
pip install -e ".[dev]"
```

To ensure that you follow the development workflow, please setup the pre-commit hooks:

```shell
pre-commit install
```
7 changes: 3 additions & 4 deletions pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
"tqdm>=4.60.0",
]
name = "scikit-llm"
version = "0.1.0b3"
version = "0.1.0"
authors = [
{ name="Oleg Kostromin", email="kostromin97@gmail.com" },
{ name="Iryna Kondrashchenko", email="iryna230520@gmail.com" },
@@ -24,10 +24,9 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dynamic = ["optional-dependencies"]

[tool.setuptools.dynamic.optional-dependencies]
dev = { file = ["requirements-dev.txt"] }
[project.optional-dependencies]
gpt4all = ["gpt4all>=0.2.0"]

[tool.ruff]
select = [
13 changes: 13 additions & 0 deletions skllm/completions.py
@@ -0,0 +1,13 @@
from skllm.gpt4all_client import get_chat_completion as _g4a_get_chat_completion
from skllm.openai.chatgpt import get_chat_completion as _oai_get_chat_completion


def get_chat_completion(
messages, openai_key=None, openai_org=None, model="gpt-3.5-turbo", max_retries=3
):
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
else:
return _oai_get_chat_completion(
messages, openai_key, openai_org, model, max_retries
)
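
For context, a minimal usage sketch of this dispatcher might look as follows; it is not part of the commit, and the prompt content is made up. The message dictionaries are built with `construct_message` from `skllm.openai.chatgpt`, and the `gpt4all::` prefix routes the call to the local backend, so the OpenAI key/org arguments can be omitted.

```python
# Illustrative sketch only (not part of this commit); requires the gpt4all extra
# and downloads the model file on first use.
from skllm.completions import get_chat_completion
from skllm.openai.chatgpt import construct_message

messages = [
    construct_message("system", "You are a text classification assistant."),
    construct_message("user", 'Classify the sentiment of: "I loved this movie."'),
]

# The "gpt4all::" prefix routes the request to the local backend; the prefix is
# stripped (model[9:]) before the model name is passed on.
completion = get_chat_completion(messages, model="gpt4all::ggml-gpt4all-j-v1.3-groovy")
```
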
23 changes: 23 additions & 0 deletions skllm/gpt4all_client.py
@@ -0,0 +1,23 @@
try:
from gpt4all import GPT4All
except (ImportError, ModuleNotFoundError):
GPT4All = None

_loaded_models = {}


def get_chat_completion(messages, model="ggml-gpt4all-j-v1.3-groovy"):
if GPT4All is None:
raise ImportError(
"gpt4all is not installed, try `pip install scikit-llm[gpt4all]`"
)
if model not in _loaded_models.keys():
_loaded_models[model] = GPT4All(model)
return _loaded_models[model].chat_completion(
messages, verbose=False, streaming=False, temp=1e-10
)


def unload_models():
global _loaded_models
_loaded_models = {}
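
The client keeps every loaded GPT4All model in the module-level `_loaded_models` cache so that repeated calls reuse the same instance, and `unload_models()` clears that cache. A small, illustrative sketch (not part of this commit; the prompt text is made up):

```python
# Requires the gpt4all extra; the model file is downloaded on first use.
from skllm.gpt4all_client import get_chat_completion, unload_models

messages = [{"role": "user", "content": 'Return a JSON object with the key "label".'}]

# The first call loads the model and caches it in _loaded_models; subsequent calls
# with the same model name reuse the cached instance.
completion = get_chat_completion(messages, model="ggml-gpt4all-j-v1.3-groovy")

# Release the memory held by cached models once they are no longer needed.
unload_models()
```
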
28 changes: 18 additions & 10 deletions skllm/models/gpt_zero_shot_clf.py
@@ -8,11 +8,8 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from tqdm import tqdm

from skllm.openai.chatgpt import (
construct_message,
extract_json_key,
get_chat_completion,
)
from skllm.completions import get_chat_completion
from skllm.openai.chatgpt import construct_message, extract_json_key
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from skllm.prompts.builders import (
build_zero_shot_prompt_mlc,
@@ -101,14 +98,25 @@ def _get_prompt(self, x) -> str:
def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
label = str(
extract_json_key(completion.choices[0].message["content"], "label")
)
except Exception:
if self.openai_model.startswith("gpt4all::"):
label = str(
extract_json_key(
completion["choices"][0]["message"]["content"], "label"
)
)
else:
label = str(
extract_json_key(completion.choices[0].message["content"], "label")
)
except Exception as e:
print(completion)
print(f"Could not extract the label from the completion: {str(e)}")
label = ""

if label not in self.classes_:
label = random.choices(self.classes_, self.probabilities_)[0]
label = label.replace("'", "").replace('"', "")
if label not in self.classes_: # try again
label = random.choices(self.classes_, self.probabilities_)[0]
return label

def fit(
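
The branching in `_predict_single` above exists because the gpt4all backend returns plain nested dictionaries, while the OpenAI backend returns a response object with attribute access on `choices`. A tiny sketch of the two access patterns (the values are invented for illustration and are not part of the commit):

```python
# gpt4all backend: plain nested dictionaries
local_completion = {"choices": [{"message": {"content": '{"label": "positive"}'}}]}
content = local_completion["choices"][0]["message"]["content"]

# OpenAI backend: attribute access on `choices`, mapping access on `message`
# content = openai_completion.choices[0].message["content"]
```
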
27 changes: 19 additions & 8 deletions skllm/openai/chatgpt.py
@@ -1,34 +1,45 @@
import openai
from time import sleep
import json
from time import sleep

import openai

from skllm.openai.credentials import set_credentials
from skllm.utils import find_json_in_string


def construct_message(role, content):
if role not in ("system", "user", "assistant"):
raise ValueError("Invalid role")
return {"role": role, "content": content}

def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries = 3):

def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3):
set_credentials(key, org)
error_msg = None
error_type = None
for _ in range(max_retries):
try:
completion = openai.ChatCompletion.create(
model=model, temperature=0., messages=messages
model=model, temperature=0.0, messages=messages
)
return completion
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
error_type = type(e).__name__
sleep(3)
print(f"Could not obtain the completion after {max_retries} retries: `{error_type} :: {error_msg}`")
print(
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
f" {error_msg}`"
)


def extract_json_key(json_, key):
try:
as_json = json.loads(json_.replace('\n', ''))
json_ = json_.replace("\n", "")
json_ = find_json_in_string(json_)
as_json = json.loads(json_)
if key not in as_json.keys():
raise KeyError("The required key was not found")
return as_json[key]
except Exception as e:
except Exception:
return None
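
As a quick illustration of the more tolerant extraction logic above (the completion text is made up; this snippet is not part of the commit):

```python
from skllm.openai.chatgpt import extract_json_key

noisy_completion = 'Sure! Here is the result:\n{"label": "positive"}\nLet me know if you need anything else.'

# Newlines are stripped and find_json_in_string trims the surrounding prose down
# to the {...} span, so the label can still be parsed; on any failure None is returned.
label = extract_json_key(noisy_completion, "label")  # -> "positive"
```
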
17 changes: 14 additions & 3 deletions skllm/utils.py
@@ -1,11 +1,22 @@
import numpy as np
import numpy as np
import pandas as pd


def to_numpy(X):
if isinstance(X, pd.Series):
X = X.to_numpy().astype(object)
elif isinstance(X, list):
X = np.asarray(X, dtype = object)
X = np.asarray(X, dtype=object)
if isinstance(X, np.ndarray) and len(X.shape) > 1:
X = np.squeeze(X)
return X
return X


def find_json_in_string(string):
start = string.find("{")
end = string.rfind("}")
if start != -1 and end != -1:
json_string = string[start : end + 1]
else:
json_string = {}
return json_string
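
A short, illustrative check of the two helpers above (inputs invented; not part of the commit):

```python
from skllm.utils import find_json_in_string, to_numpy

print(find_json_in_string('prefix {"label": "neutral"} suffix'))  # {"label": "neutral"}
print(find_json_in_string("no braces here"))                      # {} (empty-dict fallback)

# to_numpy normalizes lists and pandas Series into 1-D object arrays for the estimators.
print(to_numpy(["first text", "second text"]).shape)              # (2,)
```
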
