feat: Use Taskprocessing TextToText provider as LLM
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
marcelklehr committed Jul 22, 2024
1 parent b24ef9c commit 78cd111
Showing 8 changed files with 95 additions and 4 deletions.
2 changes: 2 additions & 0 deletions config.cpu.yaml
@@ -40,6 +40,8 @@ embedding:
   device: cpu

 llm:
+  nc_texttotext:
+
   llama:
     model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
     n_batch: 512
4 changes: 3 additions & 1 deletion config.gpu.yaml
@@ -40,6 +40,8 @@ embedding:
   device: cuda

 llm:
+  nc_texttotext:
+
   llama:
     model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
     n_batch: 512
@@ -69,4 +71,4 @@ llm:
     pipeline_kwargs:
       config:
         max_length: 200
-    template: ""
+    template: ""
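For context: the nc_texttotext key added to both config files carries no options, and in YAML an empty value parses to null, so the key's presence alone is what selects the Nextcloud provider. A quick illustrative sketch of what the loader receives (assuming PyYAML-style parsing; this is not the backend's actual config loading code):

import yaml  # assumption: PyYAML-compatible parsing, for illustration only

config = yaml.safe_load("""
llm:
  nc_texttotext:

  llama:
    model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
    n_batch: 512
""")

print(config['llm']['nc_texttotext'])  # None: the key's presence selects the provider
print(list(config['llm']))             # ['nc_texttotext', 'llama']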
2 changes: 1 addition & 1 deletion context_chat_backend/chain/query_proc.py
@@ -25,7 +25,7 @@ def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, text_
         or llm_config.get('config', {}).get('max_new_tokens') \
         or max(
             llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_new_tokens', 0),
-            llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length')
+            llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length', 0)
         ) \
         or 4096

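The added default matters because max() cannot compare an int with None: when a provider config defines no pipeline_kwargs at all (as with the new nc_texttotext entry), the old lookup returned None and the whole expression raised a TypeError instead of falling back to 4096. A minimal sketch of the fallback chain after the fix (hypothetical config values, not the repository's code):

llm_config: dict = {}  # hypothetical provider config with no token limits set

n_ctx = (
    llm_config.get('config', {}).get('max_new_tokens')
    or max(
        llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_new_tokens', 0),
        llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length', 0),  # default 0 avoids max(0, None)
    )
    or 4096
)

print(n_ctx)  # 4096: max(0, 0) is falsy, so the final fallback applies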
3 changes: 2 additions & 1 deletion context_chat_backend/download.py
@@ -202,7 +202,8 @@ def background_init(app: FastAPI):
     for model_type in ('embedding', 'llm'):
         model_name = _get_model_name_or_path(config, model_type)
         if model_name is None:
-            raise Exception(f'Error: Model name/path not found for {model_type}')
+            update_progress(app, progress := progress + 50)
+            continue

         if not _download_model(model_name):
             raise Exception(f'Error: Model download failed for {model_name}')
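A missing model name is no longer fatal: a provider that ships no local model (such as nc_texttotext) is skipped, and the progress counter, which appears to reserve 50 points per model type, still advances. A small illustration of the walrus-operator update used above (hypothetical values, outside the repository's code):

progress = 0

for model_type in ('embedding', 'llm'):
    model_name = None if model_type == 'llm' else 'some-embedding-model'  # hypothetical
    if model_name is None:
        # := updates progress in place and passes the new value along
        print(f'skipping download for {model_type}', progress := progress + 50)
        continue
    print(f'downloading {model_name}', progress := progress + 50)

print(progress)  # 100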
2 changes: 1 addition & 1 deletion context_chat_backend/models/__init__.py
@@ -5,7 +5,7 @@
 from langchain.schema.embeddings import Embeddings

 _embedding_models = ['llama', 'hugging_face', 'instructor']
-_llm_models = ['llama', 'hugging_face', 'ctransformer']
+_llm_models = ['nc_texttotext', 'llama', 'hugging_face', 'ctransformer']

 models = {
     'embedding': _embedding_models,
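Registering 'nc_texttotext' here presumably lets the existing model loader resolve the config key to the new module added below. A hedged sketch of how such a registry is typically consumed (the repository's actual loader may differ; load_llm is a hypothetical name):

from importlib import import_module

_llm_models = ['nc_texttotext', 'llama', 'hugging_face', 'ctransformer']

def load_llm(model_name: str, model_config: dict):
    # Hypothetical illustration: map the config key to a module of the same
    # name under context_chat_backend/models/ and delegate to get_model_for().
    if model_name not in _llm_models:
        raise ValueError(f'Unknown llm model: {model_name}')
    module = import_module(f'context_chat_backend.models.{model_name}')
    return module.get_model_for('llm', model_config)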
83 changes: 83 additions & 0 deletions context_chat_backend/models/nc_texttotext.py
@@ -0,0 +1,83 @@
import json
import time
from typing import Any, Dict, List, Optional

from nc_py_api import Nextcloud
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


def get_model_for(model_type: str, model_config: dict):
    if model_config is None:
        return None

    if model_type == 'llm':
        return CustomLLM()

    return None


class CustomLLM(LLM):
    """A custom LLM that queries Nextcloud's TaskProcessing TextToText provider."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT include the prompt.
        """
        nc = Nextcloud()

        print(json.dumps(prompt))

        # Schedule a text2text task via Nextcloud's TaskProcessing OCS API
        response = nc.ocs("POST", "/ocs/v1.php/taskprocessing/schedule", json={
            "type": "core:text2text",
            "appId": "context_chat_backend",
            "input": {
                "input": prompt
            }
        })

        task_id = response["task"]["id"]

        # Poll the task until it either succeeds or fails
        while response['task']['status'] not in ('STATUS_SUCCESSFUL', 'STATUS_FAILED'):
            time.sleep(5)
            response = nc.ocs("GET", f"/ocs/v1.php/taskprocessing/task/{task_id}")
            print(json.dumps(response))

        if response['task']['status'] == 'STATUS_FAILED':
            raise RuntimeError('Nextcloud TaskProcessing Task failed')

        return response['task']['output']['output']

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "NextcloudTextToTextProvider",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this LLM. Used for logging purposes only."""
        return "nc_texttotext"
1 change: 1 addition & 0 deletions requirements.in.txt
@@ -29,3 +29,4 @@ unstructured @ git+https://github.com/kyteinsky/unstructured@d3a404cfb541dae8e16
 unstructured-client
 weaviate-client
 xlrd
+nc_py_api
2 changes: 2 additions & 0 deletions requirements.txt
@@ -82,6 +82,7 @@ mpmath==1.3.0
 msg-parser==1.2.0
 multidict==6.0.5
 mypy-extensions==1.0.0
+nc-py-api==0.14.0
 nest-asyncio==1.6.0
 networkx==3.3
 nltk==3.8.1
@@ -189,5 +190,6 @@ websockets==12.0
 wrapt==1.16.0
 xlrd==2.0.1
 XlsxWriter==3.2.0
+xmltodict==0.13.0
 yarl==1.9.4
 zipp==3.19.2
