-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
157 lines (128 loc) · 4.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import click
import logging
import sys
import requests
import os
from api.urls import urls
from config import Config
from const import SAMPLE_QUERY
from lib.history import HistoryRepository
from lib.json_validator import JsonSchemaValidatorTool, JsonSchemaValidationException
from lib.models import SerVerlessWorkflow
from lib.ollama import Ollama
from lib.repository import VectorRepository
from lib.retriever import Retriever
from lib.validator import OutputValidator
from lib.serverless_validation import ServerlessValidation
from flask import Flask, g
# Configure logging to write to stderr. The level is taken from the LOG_LEVEL
# environment variable (upper-cased) and defaults to INFO when it is unset.
logging.basicConfig(stream=sys.stderr, level=os.environ.get('LOG_LEVEL', 'INFO').upper())
# Per-model value handed to VectorRepository as `embeddings_len`. Models that
# are not listed here fall back to 4096 in Context.__init__.
MODELS_EMBEDDINGS = {
    "llama3.2:3b": 3072,
    "granite3-moe:1b": 1024
}
class Context:
    """Holds the objects built from one ``Config`` instance.

    ``__init__`` assigns ``config``, ``ollama``, ``repo``, ``validator`` and
    ``history_repo``.
    """

    def __init__(self, config):
        """Build every collaborator from *config*.

        Reads ``base_url``, ``model``, ``db`` and ``chat_db`` from *config*.
        """
        self.config = config
        self.ollama = Ollama(config.base_url, config.model)

        db_location = config.db
        embed_fn = self.ollama.embeddings
        # Models missing from MODELS_EMBEDDINGS use 4096.
        vector_size = MODELS_EMBEDDINGS.get(config.model, 4096)
        self.repo = VectorRepository(
            db_location,
            embed_fn,
            embeddings_len=vector_size,
        )

        # NOTE(review): the schema path is relative, so it presumably resolves
        # against the current working directory - confirm.
        schema_tool = JsonSchemaValidatorTool.load_from_file("lib/schema/workflow.json")
        self.validator = OutputValidator(SerVerlessWorkflow, schema_tool)

        history_url = "sqlite:///{0}".format(config.chat_db)
        self.history_repo = HistoryRepository(
            session_id="empty",
            connection=history_url,
        )
# Module-level Flask application; its static folder is set to 'static'.
app = Flask(__name__, static_folder='static')
@app.before_request
def before_request():
    """Store a newly built ``Context`` on ``flask.g`` as ``g.ctx``.

    Registered through ``app.before_request``; each call constructs a new
    ``Config`` and a new ``Context``.
    """
    request_context = Context(Config())
    g.ctx = request_context
# @TODO delete this method
# Root click group: builds a Context from a new Config and stores it on
# ctx.obj so that sub-commands using @click.pass_obj receive it.
@click.group()
@click.pass_context
def cli(ctx):
    ctx.obj = Context(Config())
# @TODO Move this method to the services
# CLI command: fetch FILE_PATH, split it into documents and add them to the
# vector repository. Exits with status 1 when fetching fails, when no
# documents are produced, or when adding them to the repository fails.
@click.command()
@click.argument('file-path')
@click.pass_obj
def load_data(obj, file_path):
    vector_repo = obj.repo

    try:
        raw_content = Retriever.fetch(file_path)
    except Exception as err:
        click.echo(f"cannot read file-path {err}")
        sys.exit(1)

    chosen_splitter = Retriever.get_splitters(file_path)
    documents = obj.ollama.parse_document(chosen_splitter, raw_content)

    if not len(documents):
        click.echo("The len of the documents is 0")
        sys.exit(1)

    try:
        stored_ids = vector_repo.add_documents(documents)
    except Exception as err:
        click.echo(f"cannot create or storing the embeddings: {err}")
        sys.exit(1)
    else:
        # Persist only after the documents were added successfully.
        vector_repo.save()
        summary = "{0} documents added with ids {1}".format(len(documents), stored_ids)
        click.echo(summary)
# CLI command: register each entry of `urls` on the Flask app, then start it
# with app.run(debug=True). Each entry is indexed as
# (rule, view function, methods).
@click.command()
@click.pass_obj
def run(obj):
    for route in urls:
        rule, handler, verbs = route[0], route[1], route[2]
        app.add_url_rule(rule, view_func=handler, methods=verbs)
    app.run(debug=True)
# CLI command: POST a prompt to http://localhost:5000/chat, print the streamed
# reply line by line, then print the `session_id` response header.
#
# EXAMPLE is optional. When given, the prompt is read from
# examples/prompts/<EXAMPLE>.txt; otherwise SAMPLE_QUERY is sent.
@click.command()
@click.argument('example', required=False)
@click.pass_obj
def sample_request(obj, example):
    url = "http://localhost:5000/chat"
    headers = {
        'Content-type': 'application/json',
    }
    query = SAMPLE_QUERY
    if example:
        with open(f"examples/prompts/{example}.txt", "r") as fp:
            query = fp.read()
    data = {'input': query}
    # With stream=True the connection is held until the body is fully consumed
    # or the response is closed. The context manager closes it even when
    # decoding or printing a line raises.
    with requests.post(url, json=data, headers=headers, stream=True) as response:
        for line in response.iter_lines():
            print(line.decode('utf-8'))
        session_id = response.headers.get('session_id')
    click.echo(f"The session_id is: {session_id}")
# CLI command: read the workflow at FILE_PATH and run two checks on its text:
#   1. obj.validator.invoke(); a JsonSchemaValidationException is reported
#      through its get_error() output instead of aborting the command.
#   2. ServerlessValidation(workflow).run(), whose output and result flag are
#      printed.
@click.command()
@click.argument('file-path', required=True)
@click.pass_obj
def validate_json(obj, file_path):
    # The context manager closes the file even when read() raises; the
    # original open()/close() pair leaked the handle in that case.
    with open(file_path, "r") as fp:
        workflow = fp.read()

    click.echo("JSONschema validation:")
    try:
        obj.validator.invoke(workflow)
    except JsonSchemaValidationException as e:
        click.echo(e.get_error())

    click.echo("Maven compilation validation:")
    serverless_validation, valid = ServerlessValidation(workflow).run()
    click.echo(f"{serverless_validation}")
    click.echo(f"The workflow can compile, result: {valid}")
# CLI command: pass TEXT to obj.repo.retriever.invoke() and print every
# returned item: its metadata 'source' in green, then its page content in
# bright yellow.
@click.command()
@click.argument('text', required=True)
@click.pass_obj
def embedding(obj, text):
    click.echo(f"Checking text: '{text}'")
    matches = obj.repo.retriever.invoke(text)
    click.echo("Number of items: {0}".format(len(matches)))
    for index, document in enumerate(matches):
        heading = f"Document {index}: {document.metadata.get('source')}"
        click.echo(click.style(heading, fg="green"))
        body = click.style(document.page_content, fg="bright_yellow")
        click.echo(body)
        click.echo("\n\n")
# Register the sub-commands on the `cli` group.
for _command in (load_data, run, sample_request, validate_json, embedding):
    cli.add_command(_command)

# Run the command group only when the file is executed as a script.
if __name__ == '__main__':
    cli()