Create new NL2Q API. #3073

Merged (22 commits, May 30, 2024)

Changes from all commits
1,802 changes: 1,802 additions & 0 deletions data/nl2q/data_types.csv

Large diffs are not rendered by default.
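
The full file is too large to render here, but it follows the same four-column schema as the test fixture later in this diff: one row per field, giving the data type, field name, value type, and a short description. For example (rows taken from the test fixture, not from the real file):

data_type,field,type,description
test:data_type:1,field_test_1,str,field test 1 description.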

6 changes: 6 additions & 0 deletions data/nl2q/prompt_nl2q
@@ -0,0 +1,6 @@
Convert the following question to a Lucene query for Timesketch.

Sketch data types:
{data_types}
Question: {question}
Answer:
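
The {data_types} and {question} placeholders are filled at request time by Nl2qResource.build_prompt (later in this diff) using standard str.format substitution. A minimal sketch of that flow, with hypothetical values:

# Minimal sketch of the substitution performed by build_prompt;
# the data_types and question values here are hypothetical.
template = (
    "Convert the following question to a Lucene query for Timesketch.\n"
    "\n"
    "Sketch data types:\n"
    "{data_types}\n"
    "Question: {question}\n"
    "Answer:"
)
prompt = template.format(
    data_types='- "test:data_type:1" fields: [...]',
    question="which events contain field_test_1?",
)
print(prompt)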
8 changes: 7 additions & 1 deletion data/timesketch.conf
@@ -352,7 +352,7 @@ LLM_PROVIDER_CONFIGS = {
# See instructions at: https://ollama.ai/
'ollama': {
'server_url': 'http://localhost:11434',
'model': 'mistral',
'model': 'gemma:7b',
},
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
@@ -367,3 +367,9 @@ LLM_PROVIDER_CONFIGS = {
'project_id': '',
}
}

# LLM nl2q configuration
DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
LLM_PROVIDER = ''
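
The three new keys are consumed by the nl2q.py resource added below: PROMPT_NL2Q and DATA_TYPES_PATH point at the prompt template and the CSV of field descriptions, while LLM_PROVIDER must name one of the providers configured in LLM_PROVIDER_CONFIGS (it ships empty, so the endpoint aborts until an administrator sets it). A short sketch of how the resource reads them:

# Sketch of how nl2q.py reads the new keys (mirrors the resource below).
from flask import current_app

llm_provider = current_app.config.get("LLM_PROVIDER", "")  # e.g. "ollama"
prompt_file = current_app.config.get("PROMPT_NL2Q", "")    # template path
# DATA_TYPES_PATH is resolved via utils.load_csv_file("DATA_TYPES_PATH").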

5 changes: 5 additions & 0 deletions test_data/nl2q/test_data_types.csv
@@ -0,0 +1,5 @@
data_type,field,type,description
test:data_type:1,field_test_1,str,field test 1 description.
test:data_type:1,field_test_2,str,field test 2 description.
test:data_type:2,field_test_1,str,field test 1 description.
test:data_type:2,field_test_2,str,field test 2 description.
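
Run through the concatenate_values helper in nl2q.py below, the first group of this fixture renders into a single description line:

- "test:data_type:1" fields: ["field_test_1" (str, field test 1 description.), "field_test_2" (str, field test 2 description.)]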
6 changes: 6 additions & 0 deletions test_data/nl2q/test_prompt_nl2q
@@ -0,0 +1,6 @@
Convert the following question to a Lucene query for Timesketch.

Sketch data types:
{data_types}
Question: {question}
Answer:
220 changes: 220 additions & 0 deletions timesketch/api/v1/resources/nl2q.py
@@ -0,0 +1,220 @@
# Copyright 2024 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Natural language to query (NL2Q) API for version 1 of the Timesketch API."""

import logging

from flask import jsonify
from flask import request
from flask import abort
from flask import current_app
from flask_restful import Resource
from flask_login import login_required
from flask_login import current_user

import pandas as pd

from timesketch.api.v1 import utils
from timesketch.lib.llms import manager
from timesketch.lib.definitions import HTTP_STATUS_CODE_BAD_REQUEST
from timesketch.lib.definitions import HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR
from timesketch.lib.definitions import HTTP_STATUS_CODE_NOT_FOUND
from timesketch.lib.definitions import HTTP_STATUS_CODE_FORBIDDEN
from timesketch.models.sketch import Sketch


logger = logging.getLogger("timesketch.api_nl2q")


class Nl2qResource(Resource):
"""Resource to get NL2Q prediction."""

def build_prompt(self, question, sketch_id):
"""Builds the prompt.

        Args:
            question: The natural language question to convert.
            sketch_id: Sketch ID.

        Returns:
            String containing the whole prompt.
"""
prompt = ""
prompt_file = current_app.config.get("PROMPT_NL2Q", "")
try:
with open(prompt_file, "r") as file:
prompt = file.read()
prompt = prompt.format(
question=question,
data_types=self.data_types_descriptions(
self.sketch_data_types(sketch_id)
),
)
except (OSError, IOError):
abort(HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, "No prompt file found")
return prompt

def sketch_data_types(self, sketch_id):
"""Get the data types for the current sketch.

Args:
sketch_id: Sketch ID.

Returns:
            Comma-separated string of data types in the sketch.
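            Example: "test:data_type:1,test:data_type:2".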
"""
output = []
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

if not sketch.has_permission(current_user, "read"):
abort(
HTTP_STATUS_CODE_FORBIDDEN, "User does not have read access to sketch"
)

data_type_aggregation = utils.run_aggregator(
sketch_id, "field_bucket", {"field": "data_type", "limit": "1000"}
)

if not data_type_aggregation or not data_type_aggregation[0]:
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"Internal problem with the aggregations.",
)
data_types = data_type_aggregation[0].values
if not data_types:
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, "No data types in the sketch."
)
for data_type in data_types:
output.append(data_type.get("data_type"))
return ",".join(output)

def data_types_descriptions(self, data_types):
"""Creates a dict of data types and attribute descriptions.

Args:
data_types: List of data types in the sketch.

Returns:
Dict of data types and attribute descriptions.
"""
df_data_types = utils.load_csv_file("DATA_TYPES_PATH")
if df_data_types.empty:
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No data types description file or the file is empty.",
)
df_short_data_types = pd.DataFrame(
df_data_types.groupby("data_type").apply(self.concatenate_values),
columns=["fields"],
)
df_short_data_types["data_type"] = df_short_data_types.index
df_short_data_types["data_type"] = df_short_data_types["data_type"].apply(
lambda x: x.strip()
)
df_short_data_types.reset_index(drop=True, inplace=True)
output = []
for dtype in data_types.split(","):
extract = df_short_data_types[
df_short_data_types["data_type"] == dtype.strip()
]
if extract.empty:
print(f"'{dtype.strip()}' not found in [{data_types}]")
continue
output.append(extract.iloc[0]["fields"])
return "\n".join(output)

def generate_fields(self, group):
"""Generated the fields for a data type.

Args:
group: Data type fields.

Returns:
String of the generated fields.
"""
generated_fields = ", ".join(
f'"{n}" ({t}, {d})'
for n, t, d in zip(group["field"], group["type"], group["description"])
)
return generated_fields

def concatenate_values(self, group):
"""Concatenates the fields for a data type.

Args:
group: Data type fields.

Returns:
String of the concatenated fields.
"""
        concatenated_values = '- "{}" fields: [{}]'.format(
            group["data_type"].iloc[0], self.generate_fields(group)
        )
        return concatenated_values

@login_required
def post(self, sketch_id):
"""Handles POST request to the resource.

Args:
sketch_id: Sketch ID.

Returns:
JSON representing the LLM prediction.
"""
llm_provider = current_app.config.get("LLM_PROVIDER", "")
if not llm_provider:
logger.error("No LLM provider was defined in the main configuration file")
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)
form = request.json
if not form:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"No JSON data provided",
)

if "question" not in form:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"The 'question' parameter is required!",
)

question = form.get("question")
prompt = self.build_prompt(question, sketch_id)
try:
llm = manager.LLMManager().get_provider(llm_provider)()
except Exception as e: # pylint: disable=broad-except
logger.error("Error LLM Provider: {}".format(e))
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"Error in loading the LLM Provider. Please contact your "
"Timesketch administrator.",
)

try:
prediction = llm.generate(prompt)
except Exception as e: # pylint: disable=broad-except
logger.error("Error NL2Q prompt: {}".format(e))
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"An error occurred generating the NL2Q prediction via the "
"defined LLM. Please contact your Timesketch administrator.",
)
result = {"question": question, "llm_query": prediction}
return jsonify(result)
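
For reference, a hypothetical client call (the blueprint/route registration is not part of this diff, so the URL path and authentication flow below are assumptions):

# Hypothetical usage sketch; the nl2q route path and the authentication
# flow are assumptions, as route registration is not shown in this diff.
import requests

session = requests.Session()
# ... authenticate the session against the Timesketch server first ...
response = session.post(
    "https://timesketch.example.com/api/v1/sketches/1/nl2q/",
    json={"question": "which events contain field_test_1?"},
)
print(response.json())  # {"question": "...", "llm_query": "..."}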