Merge pull request #38 from iryna-kondr/feature_palm
Added Azure OpenAI + Google PaLM support
iryna-kondr authored Jul 4, 2023
2 parents 2dcdcd0 + ffc352a commit d09ba21
Showing 20 changed files with 874 additions and 438 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -161,3 +161,4 @@ cython_debug/
test.py
tmp.ipynb
tmp.py
*.pickle
66 changes: 45 additions & 21 deletions README.md
@@ -43,6 +43,21 @@ SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
- If you have a free trial OpenAI account, the [rate limits](https://platform.openai.com/docs/guides/rate-limits/overview) are not sufficient (specifically 3 requests per minute). Please switch to the "pay as you go" plan first.
- When calling `SKLLMConfig.set_openai_org`, you have to provide your organization ID and **NOT** the name. You can find your ID [here](https://platform.openai.com/account/org-settings).

### Using Azure OpenAI

```python
from skllm.config import SKLLMConfig

SKLLMConfig.set_openai_key("<YOUR_KEY>")  # use your Azure key here
SKLLMConfig.set_azure_api_base("<API_BASE>")

# start with "azure::" prefix when setting the model name
model_name = "azure::<model_name>"
# e.g. ZeroShotGPTClassifier(openai_model="azure::gpt-3.5-turbo")
```

Note: Azure OpenAI is not supported by the preprocessors at the moment.

### Using GPT4ALL

In addition to OpenAI, some of the models can use [gpt4all](https://gpt4all.io/index.html) as a backend.
@@ -66,18 +81,20 @@ ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy")

When running for the first time, the model file will be downloaded automatically.

At the moment only the following estimators support gpt4all as a backend:

- `ZeroShotGPTClassifier`
- `MultiLabelZeroShotGPTClassifier`
- `FewShotGPTClassifier`

When using gpt4all, please keep the following in mind:

1. Not all gpt4all models are commercially licensable; please consult the gpt4all website for details.
2. The accuracy of the models may be much lower than that of the models provided by OpenAI (especially gpt-4).
3. Not all of the available models have been tested; some may not work with scikit-llm at all.

## Supported models by a non-standard backend

At the moment only the following estimators support non-standard backends (gpt4all, azure):

- `ZeroShotGPTClassifier`
- `MultiLabelZeroShotGPTClassifier`
- `FewShotGPTClassifier`

### Zero-Shot Text Classification

One of ChatGPT's most powerful features is its ability to perform text classification without retraining. The only requirement is that the labels be descriptive.
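The zero-shot mechanism can be pictured as a prompt that embeds the candidate labels directly in the instruction, so the model can choose one without any retraining. The helper below is an illustrative sketch only; `build_zero_shot_prompt` is a hypothetical name and NOT skllm's actual prompt template:

```python
def build_zero_shot_prompt(text: str, labels: list) -> str:
    """Illustrative sketch of zero-shot classification prompting.

    The candidate labels are listed inside the instruction itself,
    which is why they need to be descriptive: the model has nothing
    else to go on.
    """
    label_list = ", ".join(labels)
    return (
        "Classify the following text into exactly one of these categories: "
        f"{label_list}.\n"
        f"Text: {text}\n"
        "Category:"
    )

prompt = build_zero_shot_prompt(
    "The plot was gripping.", ["positive", "negative", "neutral"]
)
```

The actual skllm estimators wrap this idea behind the familiar scikit-learn `fit`/`predict` interface.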
@@ -205,6 +222,28 @@ clf.fit(X, y)
labels = clf.predict(X)
```

### Text Classification with Google PaLM 2

At the moment, three PaLM-based models are available in test mode:
- `ZeroShotPaLMClassifier` - zero-shot text classification with PaLM 2;
- `PaLMClassifier` - fine-tunable text classifier based on PaLM 2;
- `PaLM` - fine-tunable estimator that can be trained on arbitrary text input-output pairs.

Example:

```python
from skllm.models.palm import PaLMClassifier
from skllm.datasets import get_classification_dataset

X, y = get_classification_dataset()

clf = PaLMClassifier(n_update_steps=100)
clf.fit(X, y)
labels = clf.predict(X)
```

More detailed documentation will follow soon. For now, please refer to our [official guide on Medium](https://medium.com/@iryna230520).

### Text Vectorization

As an alternative to using GPT as a classifier, it can be used solely for data preprocessing. `GPTVectorizer` allows embedding a chunk of text of arbitrary length into a fixed-dimensional vector that can be used with virtually any classification or regression model.
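The key property is that the output dimension is fixed regardless of input length. A minimal sketch of that contract, using a hypothetical `stub_embed` stand-in rather than a real GPT embedding call (real embeddings are much larger, e.g. 1536 dimensions for `text-embedding-ada-002`):

```python
# Sketch of the contract GPTVectorizer fulfils: any-length text in,
# fixed-dimensional vector out. The embedder here is a deterministic
# stand-in, not the real OpenAI embedding API.
import hashlib

EMBED_DIM = 8  # toy size for illustration only

def stub_embed(text: str) -> list:
    """Hash the text and scale the first EMBED_DIM bytes to [0, 1)."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return [b / 256 for b in digest[:EMBED_DIM]]

short_vec = stub_embed("hi")
long_vec = stub_embed("a much longer chunk of text " * 50)
```

Because every text maps to the same dimensionality, the resulting vectors can be fed straight into any scikit-learn estimator.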
@@ -273,18 +312,3 @@ t = GPTTranslator(openai_model="gpt-3.5-turbo", output_language="English")
translated_text = t.fit_transform(X)
```

## Roadmap 🧭

- [x] Zero-Shot Classification with OpenAI GPT 3/4
- [x] Multiclass classification
- [x] Multi-label classification
- [ ] Few-Shot classifier
- [x] Multiclass classification
- [ ] Multi-label classification
- [x] GPT Vectorizer
- [x] ChatGPT models
- [ ] InstructGPT models
- [ ] InstructGPT Fine-tuning (optional)
- [ ] Open source models

*The order of the elements in the roadmap is arbitrary and does not reflect the planned order of implementation.*
6 changes: 3 additions & 3 deletions skllm/__init__.py
@@ -1,7 +1,7 @@
# ordering is important here to prevent circular imports
from skllm.models.gpt_zero_shot_clf import (
from skllm.models.gpt.gpt_zero_shot_clf import (
MultiLabelZeroShotGPTClassifier,
ZeroShotGPTClassifier,
)
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
from skllm.models.gpt.gpt_few_shot_clf import FewShotGPTClassifier
from skllm.models.gpt.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
15 changes: 10 additions & 5 deletions skllm/completions.py
@@ -3,14 +3,19 @@


def get_chat_completion(
messages: dict, openai_key: str=None, openai_org: str=None, model: str="gpt-3.5-turbo", max_retries: int=3
messages: dict,
openai_key: str = None,
openai_org: str = None,
model: str = "gpt-3.5-turbo",
max_retries: int = 3,
):
"""
Gets a chat completion from the OpenAI API.
"""
"""Gets a chat completion from the OpenAI API."""
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
else:
api = "azure" if model.startswith("azure::") else "openai"
if api == "azure":
model = model[7:]
return _oai_get_chat_completion(
messages, openai_key, openai_org, model, max_retries
messages, openai_key, openai_org, model, max_retries, api=api
)
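The prefix routing introduced above can be summarised as a small pure function; `resolve_backend` is a hypothetical name used here for illustration, mirroring the `model[9:]` / `model[7:]` slicing in `get_chat_completion`:

```python
def resolve_backend(model: str) -> tuple:
    """Mirror of the backend routing in get_chat_completion:
    a 'gpt4all::' or 'azure::' prefix selects the backend and is
    stripped from the model name; anything else goes to OpenAI."""
    if model.startswith("gpt4all::"):
        return "gpt4all", model[len("gpt4all::"):]
    if model.startswith("azure::"):
        return "azure", model[len("azure::"):]
    return "openai", model
```

This keeps a single user-facing `model` string while dispatching to three different APIs.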
110 changes: 105 additions & 5 deletions skllm/config.py
@@ -3,21 +3,121 @@

_OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
_OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
_AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"

class SKLLMConfig():

class SKLLMConfig:
@staticmethod
def set_openai_key(key: str) -> None:
"""Sets the OpenAI key.
Parameters
----------
key : str
OpenAI key.
"""
os.environ[_OPENAI_KEY_VAR] = key

@staticmethod
def get_openai_key() -> Optional[str]:
"""Gets the OpenAI key.
Returns
-------
Optional[str]
OpenAI key.
"""
return os.environ.get(_OPENAI_KEY_VAR, None)

@staticmethod
def set_openai_org(key: str) -> None:
"""Sets OpenAI organization ID.
Parameters
----------
key : str
OpenAI organization ID.
"""
os.environ[_OPENAI_ORG_VAR] = key

@staticmethod
def get_openai_org() -> Optional[str]:
return os.environ.get(_OPENAI_ORG_VAR, None)
def get_openai_org() -> str:
"""Gets the OpenAI organization ID.
Returns
-------
str
OpenAI organization ID.
"""
return os.environ.get(_OPENAI_ORG_VAR, "")

@staticmethod
def get_azure_api_base() -> str:
"""Gets the API base for Azure.
Returns
-------
str
URL to be used as the base for the Azure API.
"""
base = os.environ.get(_AZURE_API_BASE_VAR, None)
if base is None:
raise RuntimeError("Azure API base is not set")
return base

@staticmethod
def set_azure_api_base(base: str) -> None:
"""Set the API base for Azure.
Parameters
----------
base : str
URL to be used as the base for the Azure API.
"""
os.environ[_AZURE_API_BASE_VAR] = base

@staticmethod
def set_azure_api_version(ver: str) -> None:
"""Set the API version for Azure.
Parameters
----------
ver : str
Azure API version.
"""
os.environ[_AZURE_API_VERSION_VAR] = ver

@staticmethod
def get_azure_api_version() -> str:
"""Gets the API version for Azure.
Returns
-------
str
Azure API version.
"""
return os.environ.get(_AZURE_API_VERSION_VAR, "2023-05-15")

@staticmethod
def get_google_project() -> Optional[str]:
"""Gets the Google Cloud project ID.
Returns
-------
Optional[str]
Google Cloud project ID.
"""
return os.environ.get(_GOOGLE_PROJECT, None)

@staticmethod
def set_google_project(project: str) -> None:
"""Sets the Google Cloud project ID.
Parameters
----------
project : str
Google Cloud project ID.
"""
os.environ[_GOOGLE_PROJECT] = project
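The getters and setters above are thin wrappers around environment variables. A sketch of the round-trip behaviour, using the same variable names defined in this file (assuming the variables are not already set in the environment):

```python
import os

# Setting a value is just writing the environment variable:
os.environ["SKLLM_CONFIG_AZURE_API_BASE"] = "https://example.openai.azure.com/"
base = os.environ.get("SKLLM_CONFIG_AZURE_API_BASE")

# get_azure_api_version falls back to a pinned default when unset:
ver = os.environ.get("SKLLM_CONFIG_AZURE_API_VERSION", "2023-05-15")
```

Storing configuration in the environment means it also survives across modules without any shared state beyond `os.environ`.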
41 changes: 41 additions & 0 deletions skllm/google/completions.py
@@ -0,0 +1,41 @@
from time import sleep

from vertexai.preview.language_models import ChatModel, TextGenerationModel

# TODO reduce code duplication for retrying logic


def get_completion(model: str, text: str, max_retries: int = 3):
    for _ in range(max_retries):
        try:
            # keep the original `model` string intact so a retry can
            # resolve the model again instead of calling startswith
            # on an already-instantiated model object
            if model.startswith("text-"):
                llm = TextGenerationModel.from_pretrained(model)
            else:
                llm = TextGenerationModel.get_tuned_model(model)
            response = llm.predict(text, temperature=0.0)
            return response.text
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
sleep(3)
print(
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
f" {error_msg}`"
)


def get_completion_chat_mode(model: str, context: str, text: str, max_retries: int = 3):
    for _ in range(max_retries):
        try:
            # keep `model` as a string so a retry can call from_pretrained again
            chat_model = ChatModel.from_pretrained(model)
            chat = chat_model.start_chat(context=context)
            response = chat.send_message(text, temperature=0.0)
            return response.text
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
sleep(3)
print(
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
f" {error_msg}`"
)
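The TODO above notes the duplicated retry logic in the two completion functions. One possible factoring is a small shared helper; `with_retries` is a hypothetical sketch, not part of the library:

```python
from time import sleep

def with_retries(fn, max_retries: int = 3, delay: float = 3.0):
    """Sketch of a shared retry helper: call fn(), retrying on any
    exception, and report the last error once the retries are
    exhausted (mirroring the message format used above)."""
    last_error = None
    for _ in range(max_retries):
        try:
            return fn()
        except Exception as e:
            last_error = f"{type(e).__name__} :: {e}"
            sleep(delay)
    print(
        f"Could not obtain the completion after {max_retries} retries: `{last_error}`"
    )

# Usage sketch: a callable that fails twice, then succeeds.
calls = []
def flaky():
    calls.append(1)
    if len(calls) < 3:
        raise RuntimeError("transient")
    return "ok"

result = with_retries(flaky, max_retries=3, delay=0.0)
```

Each completion function would then pass its one-shot body as `fn`, removing the duplicated try/except/sleep scaffolding.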
13 changes: 13 additions & 0 deletions skllm/google/tuning.py
@@ -0,0 +1,13 @@
from pandas import DataFrame
from vertexai.preview.language_models import TextGenerationModel


def tune(model: str, data: DataFrame, train_steps: int = 100):
model = TextGenerationModel.from_pretrained(model)
model.tune_model(
training_data=data,
train_steps=train_steps,
tuning_job_location="europe-west4", # the only supported training location atm
tuned_model_location="us-central1", # the only supported deployment location atm
)
return model # ._job
