import json
import os
from sys import stdout
import time
from halo import Halo
from warnings import warn
from elasticsearch import (
    ApiError,
    Elasticsearch,
    NotFoundError,
    BadRequestError,
)
from elastic_transport import ConnectionTimeout
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_elasticsearch import ElasticsearchStore

# Global variables
# Modify these if you want to use a different file, index or model
INDEX = os.getenv("ES_INDEX", "workplace-app-docs")
FILE = os.getenv("FILE", f"{os.path.dirname(__file__)}/data.json")
ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL")
ELASTICSEARCH_USER = os.getenv("ELASTICSEARCH_USER")
ELASTICSEARCH_PASSWORD = os.getenv("ELASTICSEARCH_PASSWORD")
ELASTICSEARCH_API_KEY = os.getenv("ELASTICSEARCH_API_KEY")
ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2")
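# Note: on linux-x86_64 hardware, Elastic also publishes an optimized variant,
# ".elser_model_2_linux-x86_64", which can be selected via the ELSER_MODEL
# environment variable (assumption: it is available in your deployment).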

if ELASTICSEARCH_USER:
    es = Elasticsearch(
        hosts=[ELASTICSEARCH_URL],
        basic_auth=(ELASTICSEARCH_USER, ELASTICSEARCH_PASSWORD),
    )
elif ELASTICSEARCH_API_KEY:
    es = Elasticsearch(hosts=[ELASTICSEARCH_URL], api_key=ELASTICSEARCH_API_KEY)
else:
    raise ValueError(
        "Please provide either ELASTICSEARCH_USER or ELASTICSEARCH_API_KEY"
    )
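
# Example invocation (illustrative values only, assuming the Flask CLI entry
# point mentioned at the bottom of this file):
#
#   export ELASTICSEARCH_URL="http://localhost:9200"
#   export ELASTICSEARCH_USER="elastic"
#   export ELASTICSEARCH_PASSWORD="changeme"
#   flask create-index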


def install_elser():
    # This script is re-entered on ctrl-c or someone just running it twice.
    # Hence, both steps need to be careful about being potentially redundant.

    # Step 1: Ensure ELSER_MODEL is defined
    try:
        es.ml.get_trained_models(model_id=ELSER_MODEL)
    except NotFoundError:
        print(f'"{ELSER_MODEL}" model not available, downloading it now')
        es.ml.put_trained_model(
            model_id=ELSER_MODEL, input={"field_names": ["text_field"]}
        )
        while True:
            status = es.ml.get_trained_models(
                model_id=ELSER_MODEL, include="definition_status"
            )
            if status["trained_model_configs"][0]["fully_defined"]:
                break
            time.sleep(1)

    # Step 2: Ensure ELSER_MODEL is fully allocated
    if not is_elser_fully_allocated():
        try:
            es.ml.start_trained_model_deployment(
                model_id=ELSER_MODEL, wait_for="fully_allocated"
            )
            print(f'"{ELSER_MODEL}" model is deployed')
        except BadRequestError:
            # Already started, and likely fully allocated
            pass

    print(f'"{ELSER_MODEL}" model is ready')


def is_elser_fully_allocated():
    stats = es.ml.get_trained_models_stats(model_id=ELSER_MODEL)
    deployment_stats = stats["trained_model_stats"][0].get("deployment_stats", {})
    allocation_status = deployment_stats.get("allocation_status", {})
    return allocation_status.get("state") == "fully_allocated"
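
# The trained-model stats response parsed above is shaped roughly like this,
# abridged to the fields is_elser_fully_allocated() actually reads:
#
#   {"trained_model_stats": [
#       {"deployment_stats": {"allocation_status": {"state": "fully_allocated"}}}]}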


def main():
    install_elser()

    print(f"Loading data from {FILE}")
    metadata_keys = ["name", "summary", "url", "category", "updated_at"]
    workplace_docs = []
    with open(FILE, "rt") as f:
        for doc in json.loads(f.read()):
            workplace_docs.append(
                Document(
                    page_content=doc["content"],
                    metadata={k: doc.get(k) for k in metadata_keys},
                )
            )
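
    # Each record in FILE is expected to provide "content" plus the metadata
    # keys above; illustrative shape (placeholder values):
    #
    #   {"name": "...", "summary": "...", "url": "...", "category": "...",
    #    "updated_at": "...", "content": "..."}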
print(f"Loaded {len(workplace_docs)} documents")
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=512, chunk_overlap=256
)
docs = text_splitter.transform_documents(workplace_docs)
print(f"Split {len(workplace_docs)} documents into {len(docs)} chunks")
print(f"Creating Elasticsearch sparse vector store for {ELASTICSEARCH_URL}")
store = ElasticsearchStore(
es_connection=es,
index_name=INDEX,
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL),
)
# The first call creates ML tasks to support the index, and typically fails
# with the default 10-second timeout, at least when Elasticsearch is a
# container running on Apple Silicon.
#
# Once elastic/elasticsearch#107077 is fixed, we can use bulk_kwargs to
# adjust the timeout.
print(f"Adding documents to index {INDEX}")
if stdout.isatty():
spinner = Halo(text="Processing bulk operation", spinner="dots")
spinner.start()
try:
es.indices.delete(index=INDEX, ignore_unavailable=True)
store.add_documents(list(docs))
except BadRequestError:
# This error means the index already exists
pass
except (ConnectionTimeout, ApiError) as e:
if isinstance(e, ApiError) and e.status_code != 408:
raise
warn(f"Error occurred, will retry after ML jobs complete: {e}")
await_ml_tasks()
es.indices.delete(index=INDEX, ignore_unavailable=True)
store.add_documents(list(docs))
if stdout.isatty():
spinner.stop()
print(f"Documents added to index {INDEX}")


def await_ml_tasks(max_timeout=1200, interval=5):
    """
    Waits for all machine learning tasks to complete within a specified timeout period.

    Parameters:
        max_timeout (int): Maximum time to wait for tasks to complete, in seconds.
        interval (int): Time to wait between status checks, in seconds.

    Raises:
        TimeoutError: If the timeout is reached and machine learning tasks are still running.
    """
    start_time = time.time()
    ml_tasks = get_ml_tasks()
    if not ml_tasks:
        return  # likely a lost race on tasks

    print(f"Awaiting {len(ml_tasks)} ML tasks")
    while time.time() - start_time < max_timeout:
        ml_tasks = get_ml_tasks()
        if not ml_tasks:
            return
        time.sleep(interval)

    raise TimeoutError(
        f"Timeout reached. ML tasks are still running: {', '.join(ml_tasks)}"
    )


def get_ml_tasks():
    """Return a list of ML task actions from the ES tasks API."""
    tasks = []
    resp = es.tasks.list(detailed=True, actions=["cluster:monitor/xpack/ml/*"])
    for node_info in resp["nodes"].values():
        for task_info in node_info.get("tasks", {}).values():
            tasks.append(task_info["action"])
    return tasks
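
# For reference, the tasks API response parsed above is shaped roughly like
# this, abridged to the fields get_ml_tasks() reads:
#
#   {"nodes": {"<node_id>": {"tasks": {"<task_id>": {"action": "..."}}}}}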


# Unless we run through flask, we can miss critical settings or telemetry signals.
if __name__ == "__main__":
    raise RuntimeError("Run via the parent directory: 'flask create-index'")
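
# For context, the parent Flask app is assumed to wire this module up with a
# CLI command along these lines (hypothetical sketch, not part of this file):
#
#   @app.cli.command("create-index")
#   def create_index():
#       index_data.main()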