llm_command_r.py
import click
import cohere
import llm
import pydantic
from pydantic import Field
import sqlite_utils
import sys
from typing import Optional, List
@llm.hookimpl
def register_commands(cli):
    @cli.command()
    @click.argument("prompt")
    @click.option("-s", "--system", help="System prompt to use")
    @click.option("model_id", "-m", "--model", help="Model to use")
    @click.option(
        "options",
        "-o",
        "--option",
        type=(str, str),
        multiple=True,
        help="key/value options for the model",
    )
    @click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
    @click.option("--key", help="API key to use")
    def command_r_search(prompt, system, model_id, options, no_log, key):
        "Prompt Command R with the web search feature"
        # render_errors formats pydantic validation errors for display
        from llm.cli import logs_on, logs_db_path, render_errors
        from llm.migrations import migrate

        model_id = model_id or "command-r"
        model = llm.get_model(model_id)
        if model.needs_key:
            model.key = llm.get_key(key, model.needs_key, model.key_env_var)
        validated_options = {}
        options = list(options)
        options.append(("websearch", "1"))
        try:
            validated_options = dict(
                (key, value)
                for key, value in model.Options(**dict(options))
                if value is not None
            )
        except pydantic.ValidationError as ex:
            raise click.ClickException(render_errors(ex.errors()))
        response = model.prompt(prompt, system=system, **validated_options)
        for chunk in response:
            print(chunk, end="")
            sys.stdout.flush()
        # Log to the database
        if logs_on() and not no_log:
            log_path = logs_db_path()
            (log_path.parent).mkdir(parents=True, exist_ok=True)
            db = sqlite_utils.Database(log_path)
            migrate(db)
            response.log_to_db(db)
        # Now output the citations
        documents = response.response_json.get("documents", [])
        if documents:
            print()
            print()
            print("Sources:")
            print()
            for doc in documents:
                print("-", doc["title"], "-", doc["url"])
@llm.hookimpl
def register_models(register):
    # https://docs.cohere.com/docs/models
    register(CohereMessages("command-r"), aliases=("r",))
    register(CohereMessages("command-r-plus"), aliases=("r-plus",))


class CohereMessages(llm.Model):
    needs_key = "cohere"
    key_env_var = "COHERE_API_KEY"
    can_stream = True

    class Options(llm.Options):
        websearch: Optional[bool] = Field(
            description="Use web search connector",
            default=False,
        )

    def __init__(self, model_id):
        self.model_id = model_id

    def build_chat_history(self, conversation) -> List[dict]:
        chat_history = []
        if conversation:
            for response in conversation.responses:
                chat_history.extend(
                    [
                        {"role": "USER", "text": response.prompt.prompt},
                        {"role": "CHATBOT", "text": response.text()},
                    ]
                )
        return chat_history

    def execute(self, prompt, stream, response, conversation):
        client = cohere.Client(self.get_key())
        kwargs = {
            "message": prompt.prompt,
            "model": self.model_id,
        }
        if prompt.system:
            kwargs["preamble"] = prompt.system
        if conversation:
            kwargs["chat_history"] = self.build_chat_history(conversation)
        if prompt.options.websearch:
            kwargs["connectors"] = [{"id": "web-search"}]
        if stream:
            for event in client.chat_stream(**kwargs):
                if event.event_type == "text-generation":
                    yield event.text
                elif event.event_type == "stream-end":
                    response.response_json = event.response.dict()
        else:
            event = client.chat(**kwargs)
            answer = event.text
            yield answer
            response.response_json = event.dict()

    def __str__(self):
        return "Cohere Messages: {}".format(self.model_id)