Merge pull request stanfordnlp#461 from nbqu/main
Improve google LLM support
CShorten authored Feb 27, 2024
2 parents a4a3397 + c6e96ed commit 42a5943
Showing 2 changed files with 76 additions and 18 deletions.
86 changes: 71 additions & 15 deletions dsp/modules/google.py
@@ -1,11 +1,13 @@
 import math
-from typing import Any, Optional
+import os
+from typing import Any, Iterable, Optional
 import backoff
 
 from dsp.modules.lm import LM
 
 try:
     import google.generativeai as genai
+    from google.api_core.exceptions import GoogleAPICallError
+    google_api_error = GoogleAPICallError
 except ImportError:
+    google_api_error = Exception
     # print("Not loading Google because it is not installed.")
@@ -27,14 +29,38 @@ def giveup_hdlr(details):
     return True
 
 
+BLOCK_ONLY_HIGH = [
+    {
+        "category": "HARM_CATEGORY_HARASSMENT",
+        "threshold": "BLOCK_ONLY_HIGH"
+    },
+    {
+        "category": "HARM_CATEGORY_HATE_SPEECH",
+        "threshold": "BLOCK_ONLY_HIGH"
+    },
+    {
+        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "threshold": "BLOCK_ONLY_HIGH"
+    },
+    {
+        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        "threshold": "BLOCK_ONLY_HIGH"
+    },
+]
+
+
 class Google(LM):
     """Wrapper around Google's API.
 
     Currently supported models include `gemini-pro-1.0`.
     """
 
     def __init__(
-        self, model: str = "gemini-pro-1.0", api_key: Optional[str] = None, **kwargs
+        self,
+        model: str = "models/gemini-1.0-pro",
+        api_key: Optional[str] = None,
+        safety_settings: Optional[Iterable] = BLOCK_ONLY_HIGH,
+        **kwargs
     ):
         """
         Parameters
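`BLOCK_ONLY_HIGH` follows the SDK's safety-settings schema: one category/threshold dict per harm category, with the `BLOCK_ONLY_HIGH` threshold blocking only high-probability harmful content. A rough sketch of passing such a list straight to the SDK, assuming `GOOGLE_API_KEY` is set and using the new default model name as a placeholder:

    import os
    import google.generativeai as genai

    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    # Same shape as BLOCK_ONLY_HIGH above: category + threshold pairs.
    relaxed = [{"category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_ONLY_HIGH"}]

    model = genai.GenerativeModel("models/gemini-1.0-pro",
                                  safety_settings=relaxed)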
@@ -49,16 +75,30 @@ def __init__(
             Additional arguments to pass to the API provider.
         """
         super().__init__(model)
-        self.google = genai.configure(api_key=api_key)
+        api_key = os.environ.get("GOOGLE_API_KEY") if api_key is None else api_key
+        genai.configure(api_key=api_key)
 
+        # Google API uses "candidate_count" instead of "n" or "num_generations"
+        # For now, google API only supports 1 generation at a time. Raises an error if candidate_count > 1
+        num_generations = kwargs.pop("n", kwargs.pop("num_generations", 1))
+
         self.provider = "google"
-        self.kwargs = {
-            "model_name": model,
-            "temperature": 0.0
-            if "temperature" not in kwargs
-            else kwargs["temperature"],
+        kwargs = {
+            "candidate_count": 1,
+            "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
             "max_output_tokens": 2048,
             "top_p": 1,
             "top_k": 1,
+            **kwargs
         }
 
+        self.config = genai.GenerationConfig(**kwargs)
+        self.llm = genai.GenerativeModel(model_name=model,
+                                         generation_config=self.config,
+                                         safety_settings=safety_settings)
+
+        self.kwargs = {
+            "n": num_generations,
+            **kwargs,
+        }

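In outline, the new constructor falls back to the `GOOGLE_API_KEY` environment variable when no key is passed, strips `n`/`num_generations` (the Google API only accepts `candidate_count`, and currently only a value of 1), and bakes the sampling defaults into a `genai.GenerationConfig`. A condensed sketch of the same flow, assuming the `google-generativeai` package and a valid key:

    import os
    import google.generativeai as genai

    genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

    # Gemini takes `candidate_count`, not `n`, and (at this time) only 1.
    config = genai.GenerationConfig(
        candidate_count=1,
        temperature=0.0,
        max_output_tokens=2048,
        top_p=1,
        top_k=1,
    )
    llm = genai.GenerativeModel(model_name="models/gemini-1.0-pro",
                                generation_config=config)
    response = llm.generate_content("2 + 2 = ?")
    print(response.text)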
@@ -68,14 +108,19 @@ def basic_request(self, prompt: str, **kwargs):
         raw_kwargs = kwargs
         kwargs = {
             **self.kwargs,
-            "prompt": prompt,
             **kwargs,
         }
-        response = self.co.generate(**kwargs)
 
+        # Google disallows "n" arguments
+        n = kwargs.pop("n", None)
+        if n is not None and n > 1 and kwargs['temperature'] == 0.0:
+            kwargs['temperature'] = 0.7
+
+        response = self.llm.generate_content(prompt, generation_config=kwargs)
+
         history = {
             "prompt": prompt,
-            "response": [response],
+            "response": [response],
             "kwargs": kwargs,
             "raw_kwargs": raw_kwargs,
         }
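One subtlety behind the temperature bump: the API returns a single candidate per request, so multiple samples come from repeated calls, and repeated greedy decoding at temperature 0.0 would yield identical strings. The guard therefore forces a nonzero temperature (0.7 is an arbitrary choice) when more than one sample is wanted. The logic in isolation, with `adjust_for_sampling` as a hypothetical helper name:

    def adjust_for_sampling(kwargs: dict) -> dict:
        """Drop the unsupported 'n'; force nonzero temperature when n > 1."""
        n = kwargs.pop("n", None)
        if n is not None and n > 1 and kwargs.get("temperature", 0.0) == 0.0:
            kwargs["temperature"] = 0.7  # any nonzero value restores diversity
        return kwargs

    print(adjust_for_sampling({"n": 3, "temperature": 0.0}))
    # -> {'temperature': 0.7}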
@@ -85,8 +130,9 @@ def basic_request(self, prompt: str, **kwargs):
 
     @backoff.on_exception(
         backoff.expo,
-        (Exception),
+        (google_api_error),
         max_time=1000,
+        max_tries=8,
         on_backoff=backoff_hdlr,
         giveup=giveup_hdlr,
     )
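Retrying now targets `google_api_error` instead of every `Exception`, and the new `max_tries=8` bounds the retry loop alongside the existing 1000-second cap. A self-contained sketch of the same `backoff` configuration; the decorated function and the bare `Exception` are stand-ins:

    import backoff

    def backoff_hdlr(details):
        # `details` carries wait, tries, target, and elapsed time.
        print("Backing off {wait:0.1f}s after {tries} tries".format(**details))

    @backoff.on_exception(
        backoff.expo,      # exponential waits: ~1s, 2s, 4s, ...
        Exception,         # stand-in for google_api_error from the module
        max_time=1000,     # stop retrying after ~1000s in total
        max_tries=8,       # ... or after 8 attempts, whichever comes first
        on_backoff=backoff_hdlr,
    )
    def flaky_request():
        raise RuntimeError("simulated transient failure")

    # Calling flaky_request() would retry 8 times, then re-raise.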
@@ -99,6 +145,16 @@ def __call__(
         prompt: str,
         only_completed: bool = True,
         return_sorted: bool = False,
-        **kwargs,
+        **kwargs
     ):
-        return self.request(prompt, **kwargs)
+        assert only_completed, "for now"
+        assert return_sorted is False, "for now"
+
+        n = kwargs.pop("n", 1)
+
+        completions = []
+        for i in range(n):
+            response = self.request(prompt, **kwargs)
+            completions.append(response.parts[0].text)
+
+        return completions
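With these changes, multi-sample generation is emulated client-side: `__call__` pops `n` and issues that many sequential single-candidate requests, collecting `response.parts[0].text` from each. A hypothetical usage sketch, assuming `dsp` is importable and `GOOGLE_API_KEY` is exported; the prompt is a placeholder:

    from dsp.modules.google import Google

    lm = Google()  # defaults to models/gemini-1.0-pro
    completions = lm("Briefly explain retrieval augmentation.", n=2)
    for text in completions:
        print(text)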
8 changes: 5 additions & 3 deletions dsp/modules/lm.py
@@ -46,14 +46,14 @@ def inspect_history(self, n: int = 1, skip: int = 0):
 
         if prompt != last_prompt:
 
-            if provider=="clarifai":
+            if provider == "clarifai" or provider == "google":
                 printed.append(
                     (
                         prompt,
                         x['response']
                     )
-                    )
+                )
-            else:
+            else:
                 printed.append(
                     (
                         prompt,
@@ -82,6 +82,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
                 text = ' ' + self._get_choice_text(choices[0]).strip()
             elif provider == "clarifai":
                 text=choices
+            elif provider == "google":
+                text = choices[0].parts[0].text
             else:
                 text = choices[0]["text"]
             self.print_green(text, end="")
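This branch works because `basic_request` now stores the Gemini response wrapped in a list (`"response": [response]`), so `choices[0]` is the raw response object and the printable text sits at `parts[0].text` rather than in an OpenAI-style `choices[0]["text"]` dict. Schematically, with stand-in classes substituting for a real `GenerateContentResponse`:

    # Stand-ins for a Gemini response, just for illustration.
    class _Part:
        text = "Bonjour!"

    class _Response:
        parts = [_Part()]

    entry = {"response": [_Response()]}   # shape written by basic_request
    choices = entry["response"]
    print(choices[0].parts[0].text)       # what inspect_history prints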