Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AIStudio as a supported LLM library #3254

Merged
merged 14 commits into from
Jan 16, 2025
8 changes: 7 additions & 1 deletion data/timesketch.conf
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,13 @@ LLM_PROVIDER_CONFIGS = {
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
}
},
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# pip install google-generativeai
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
},
}

# LLM nl2q configuration
Expand Down
1 change: 1 addition & 0 deletions timesketch/lib/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@

from timesketch.lib.llms import ollama
from timesketch.lib.llms import vertexai
from timesketch.lib.llms import aistudio
94 changes: 94 additions & 0 deletions timesketch/lib/llms/aistudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Google AI Studio LLM provider."""

import json
from typing import Optional
from timesketch.lib.llms import interface
from timesketch.lib.llms import manager


# Check if the required dependencies are installed.
# google-generativeai is an optional dependency; when it is missing the
# provider class below is still defined, but it is never registered with
# the LLM manager (see the guard at the bottom of this module).
has_required_deps = True
try:
    import google.generativeai as genai
except ImportError:
    has_required_deps = False


class AIStudio(interface.LLMProvider):
    """AI Studio LLM provider.

    Generates text via Google's AI Studio (google-generativeai) API,
    authenticated with an API key from the provider configuration.
    """

    NAME = "aistudio"

    def __init__(self, config: dict):
        """Initialize the AI Studio provider.

        Args:
            config: The configuration for the provider. Recognized keys:
                api_key (required), model, temperature, top_p, top_k,
                max_output_tokens.

        Raises:
            ValueError: If no API key is configured.
        """
        super().__init__(config)
        self._api_key = self.config.get("api_key")
        # NOTE(review): this fallback differs from the sample default in
        # timesketch.conf ("gemini-2.0-flash-exp") -- confirm which model
        # is intended as the default.
        self._model_name = self.config.get("model", "gemini-1.5-flash")
        self._temperature = self.config.get("temperature", 0.2)
        self._top_p = self.config.get("top_p", 0.95)
        self._top_k = self.config.get("top_k", 10)
        self._max_output_tokens = self.config.get("max_output_tokens", 8192)

        # Fail fast before configuring the genai client: the API key is
        # the only mandatory setting.
        if not self._api_key:
            raise ValueError("API key is required for AI Studio provider")
        genai.configure(api_key=self._api_key)
        self.model = genai.GenerativeModel(model_name=self._model_name)

    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
        """
        Generate text using the Google AI Studio service.

        Args:
            prompt: The prompt to use for the generation.
            response_schema: An optional JSON schema to define the expected
                response format.

        Returns:
            The generated text as a string (or parsed data if
            response_schema is provided).

        Raises:
            ValueError: If response_schema is given and the model response
                is not valid JSON.
        """
        generation_config = genai.GenerationConfig(
            temperature=self._temperature,
            top_p=self._top_p,
            top_k=self._top_k,
            max_output_tokens=self._max_output_tokens,
        )

        if response_schema:
            # Ask the model to emit JSON conforming to the caller's schema.
            generation_config.response_mime_type = "application/json"
            generation_config.response_schema = response_schema

        response = self.model.generate_content(
            contents=prompt,
            generation_config=generation_config,
        )

        # Read the text once; .text itself can raise for blocked/empty
        # responses, and we do not want that masked by the JSON handler.
        text = response.text
        if response_schema:
            try:
                return json.loads(text)
            # Catch only JSON parse failures; the original broad
            # `except Exception` also swallowed unrelated errors.
            except json.JSONDecodeError as error:
                raise ValueError(
                    f"Error JSON parsing text: {text}: {error}"
                ) from error
        return text


# Only expose the provider to the manager when the optional
# google-generativeai dependency was importable above.
if has_required_deps:
    manager.LLMManager.register_provider(AIStudio)
7 changes: 4 additions & 3 deletions timesketch/lib/llms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Interface for LLM providers."""

import string
from typing import Optional

from flask import current_app

Expand Down Expand Up @@ -82,13 +83,13 @@ def prompt_from_template(self, template: str, kwargs: dict) -> str:
formatter = string.Formatter()
return formatter.format(template, **kwargs)

def generate(self, prompt: str) -> str:
def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
"""Generate a response from the LLM provider.

Args:
prompt: The prompt to generate a response for.
temperature: The temperature to use for the response.
stream: Whether to stream the response.
response_schema: An optional JSON schema to define the expected
response format.

Returns:
The generated response.
Expand Down
Loading