diff --git a/data/timesketch.conf b/data/timesketch.conf index fb25a1732b..dad4a5263e 100644 --- a/data/timesketch.conf +++ b/data/timesketch.conf @@ -366,7 +366,13 @@ LLM_PROVIDER_CONFIGS = { 'vertexai': { 'model': 'gemini-1.5-flash-001', 'project_id': '', - } + }, + # To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/ + # pip install google-generativeai + 'aistudio': { + 'api_key': '', + 'model': 'gemini-2.0-flash-exp', + }, } # LLM nl2q configuration diff --git a/timesketch/lib/llms/__init__.py b/timesketch/lib/llms/__init__.py index 6612f76aa0..bb52e18d42 100644 --- a/timesketch/lib/llms/__init__.py +++ b/timesketch/lib/llms/__init__.py @@ -15,3 +15,4 @@ from timesketch.lib.llms import ollama from timesketch.lib.llms import vertexai +from timesketch.lib.llms import aistudio diff --git a/timesketch/lib/llms/aistudio.py b/timesketch/lib/llms/aistudio.py new file mode 100644 index 0000000000..77b6502efa --- /dev/null +++ b/timesketch/lib/llms/aistudio.py @@ -0,0 +1,94 @@ +# Copyright 2024 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Google AI Studio LLM provider.""" + +import json +from typing import Optional +from timesketch.lib.llms import interface +from timesketch.lib.llms import manager + + +# Check if the required dependencies are installed. 
# Check whether the optional google-generativeai dependency is installed;
# the provider is only registered with the manager when the import succeeds.
has_required_deps = True
try:
    import google.generativeai as genai
except ImportError:
    has_required_deps = False


class AIStudio(interface.LLMProvider):
    """Google AI Studio LLM provider."""

    NAME = "aistudio"

    def __init__(self, config: dict):
        """Initialize the AI Studio provider.

        Args:
            config: The configuration for the provider. Recognized keys:
                api_key (required), model, temperature, top_p, top_k,
                and max_output_tokens.

        Raises:
            ValueError: If no API key is configured.
        """
        super().__init__(config)
        self._api_key = self.config.get("api_key")
        # NOTE(review): this fallback differs from the sample config, which
        # ships "gemini-2.0-flash-exp" — confirm which default is intended.
        self._model_name = self.config.get("model", "gemini-1.5-flash")
        self._temperature = self.config.get("temperature", 0.2)
        self._top_p = self.config.get("top_p", 0.95)
        self._top_k = self.config.get("top_k", 10)
        self._max_output_tokens = self.config.get("max_output_tokens", 8192)

        if not self._api_key:
            raise ValueError("API key is required for AI Studio provider")
        genai.configure(api_key=self._api_key)
        self.model = genai.GenerativeModel(model_name=self._model_name)

    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
        """Generate text using the Google AI Studio service.

        Args:
            prompt: The prompt to use for the generation.
            response_schema: An optional JSON schema to define the expected
                response format.

        Returns:
            The generated text as a string (or parsed data if
            response_schema is provided).

        Raises:
            ValueError: If a response_schema is provided but the model's
                reply is not valid JSON.
        """
        generation_config = genai.GenerationConfig(
            temperature=self._temperature,
            top_p=self._top_p,
            top_k=self._top_k,
            max_output_tokens=self._max_output_tokens,
        )

        if response_schema:
            # Ask the model to emit JSON conforming to the caller's schema.
            generation_config.response_mime_type = "application/json"
            generation_config.response_schema = response_schema

        response = self.model.generate_content(
            contents=prompt,
            generation_config=generation_config,
        )

        if response_schema:
            try:
                return json.loads(response.text)
            # Catch only JSON decoding failures rather than a blanket
            # Exception, so unrelated errors (e.g. a blocked response when
            # accessing response.text) surface with their real type.
            except json.JSONDecodeError as error:
                raise ValueError(
                    f"Error JSON parsing text: {response.text}: {error}"
                ) from error
        return response.text


if has_required_deps:
    manager.LLMManager.register_provider(AIStudio)