Skip to content

Commit

Permalink
Merge pull request #31 from WladimirLct/main
Browse files Browse the repository at this point in the history
Add support for Groq's API
  • Loading branch information
potsawee authored May 31, 2024
2 parents 9d01cea + 69d2a71 commit abe58b3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
selfcheck_prompt = SelfCheckLLMPrompt(llm_model, device)

# Option2: API access (currently only support client_type="openai")
# Option2: API access
# (currently only support OpenAI and Groq)
# from selfcheckgpt.modeling_selfcheck_apiprompt import SelfCheckAPIPrompt
# selfcheck_prompt = SelfCheckAPIPrompt(client_type="openai", model="gpt-3.5-turbo")
# selfcheck_prompt = SelfCheckAPIPrompt(client_type="groq", model="llama3-70b-8192", api_key="your-api-key")

sent_scores_prompt = selfcheck_prompt.predict(
sentences = sentences, # list of sentences
Expand Down
8 changes: 6 additions & 2 deletions selfcheckgpt/modeling_selfcheck_apiprompt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from openai import OpenAI
from groq import Groq
from tqdm import tqdm
from typing import Dict, List, Set, Tuple, Union
import numpy as np
Expand All @@ -11,14 +12,17 @@ def __init__(
self,
client_type = "openai",
model = "gpt-3.5-turbo",
api_key = None,
):
assert client_type in ["openai"]
if client_type == "openai":
# using default keys
# os.environ.get("OPENAI_ORGANIZATION")
# os.environ.get("OPENAI_API_KEY")
self.client = OpenAI()
print("Initiate OpenAI client... model = {}".format(model))
elif client_type == "groq":
self.client = Groq(api_key=api_key)
print("Initiate Groq client... model = {}".format(model))

self.client_type = client_type
self.model = model
Expand All @@ -31,7 +35,7 @@ def set_prompt_template(self, prompt_template: str):
self.prompt_template = prompt_template

def completion(self, prompt: str):
if self.client_type == "openai":
if self.client_type == "openai" or self.client_type == "groq":
chat_completion = self.client.chat.completions.create(
model=self.model,
messages=[
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"spacy",
"nltk",
"openai",
"groq",
]

# some more details
Expand Down

0 comments on commit abe58b3

Please sign in to comment.