Skip to content

Commit

Permalink
Merge pull request #43 from iryna-kondr/fix_gpt4allv1
Browse files Browse the repository at this point in the history
Added support for gpt4all>=1.0
  • Loading branch information
iryna-kondr authored Jul 24, 2023
2 parents c9aa097 + 16f8f26 commit 2b2fc82
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
```python
from skllm.config import SKLLMConfig

SKLLMConfig.set_openai_key("<YOUR_KEY>") #use azure key instead
SKLLMConfig.set_openai_key("<YOUR_KEY>") # use azure key instead
SKLLMConfig.set_azure_api_base("<API_BASE>")

# start with "azure::" prefix when setting the model name
Expand Down Expand Up @@ -76,7 +76,7 @@ In order to switch from OpenAI to GPT4ALL model, simply provide a string of the
SKLLMConfig.set_openai_key("any string")
SKLLMConfig.set_openai_org("any string")

ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy")
ZeroShotGPTClassifier(openai_model="gpt4all::ggml-model-gpt4all-falcon-q4_0.bin")
```

When running for the first time, the model file will be downloaded automatically.
Expand Down Expand Up @@ -225,11 +225,12 @@ labels = clf.predict(X)
### Text Classification with Google PaLM 2

At the moment 3 PaLM based models are available in test mode:

- `ZeroShotPaLMClassifier` - zero-shot text classification with PaLM 2;
- `PaLMClassifier` - fine-tunable text classifier with PaLM 2;
- `PaLM` - fine-tunable estimator that can be trained on arbitrary text input-output pairs.

Example:
Example:

```python
from skllm.models.palm import PaLMClassifier
Expand Down Expand Up @@ -311,4 +312,3 @@ X = get_translation_dataset()
t = GPTTranslator(openai_model="gpt-3.5-turbo", output_language="English")
translated_text = t.fit_transform(X)
```

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]

[project.optional-dependencies]
gpt4all = ["gpt4all>=0.2.0"]
gpt4all = ["gpt4all>=1.0.0"]

[tool.ruff]
select = [
Expand Down
26 changes: 20 additions & 6 deletions skllm/gpt4all_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
_loaded_models = {}


def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy") -> Dict:
"""
Gets a chat completion from GPT4All
def _make_openai_compatabile(message: str) -> Dict:
return {"choices": [{"message": {"content": message, "role": "assistant"}}]}


def get_chat_completion(
messages: Dict, model: str = "ggml-model-gpt4all-falcon-q4_0.bin"
) -> Dict:
"""Gets a chat completion from GPT4All.
Parameters
----------
Expand All @@ -28,11 +33,20 @@ def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy")
"gpt4all is not installed, try `pip install scikit-llm[gpt4all]`"
)
if model not in _loaded_models.keys():
_loaded_models[model] = GPT4All(model)
loaded_model = GPT4All(model)
_loaded_models[model] = loaded_model
loaded_model._current_prompt_template = loaded_model.config["promptTemplate"]

return _loaded_models[model].chat_completion(
messages, verbose=False, streaming=False, temp=1e-10
prompt = _loaded_models[model]._format_chat_prompt_template(
messages, _loaded_models[model].config["systemPrompt"]
)
generated = _loaded_models[model].generate(
prompt,
streaming=False,
temp=1e-10,
)

return _make_openai_compatabile(generated)


def unload_models() -> None:
Expand Down

0 comments on commit 2b2fc82

Please sign in to comment.