webui_llm.py
from typing import Any, List, Optional

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from pydantic import BaseModel

from modules.text_generation import generate_reply


class WebUILLM(BaseLLM, BaseModel):
    """LangChain LLM wrapper around text-generation-webui's generate_reply()."""

    streaming: bool = False
    """Whether to stream the results or not."""
    generation_attempts: int = 1
    """Number of generations per prompt."""
    max_new_tokens: int = 200
    """Maximum number of newly generated tokens."""
    do_sample: bool = True
    """Whether to sample."""
    temperature: float = 0.72
    """Creativity of the model."""
    top_p: float = 0.18
    """Top P."""
    typical_p: float = 1
    """Typical P."""
    top_k: int = 30
    """Top K."""
    min_length: int = 0
    """Minimum length of the generated result."""
    repetition_penalty: float = 1.5
    """Penalizes repetition."""
    encoder_repetition_penalty: float = 1
    """Penalizes encoder repetition."""
    penalty_alpha: float = 0
    """Alpha for contrastive search penalties."""
    no_repeat_ngram_size: int = 0
    """Size of n-grams for the repetition penalty."""
    num_beams: int = 1
    """Number of beams."""
    length_penalty: int = 1
    """Penalizes length."""
    seed: int = -1
    """Generation seed."""
    generate_state: Any = None
    """State dict forwarded to generate_reply()."""

    def set_state(self, state):
        self.generate_state = state

    @property
    def _llm_type(self) -> str:
        return "text-generation-webui"

    def _generate(self, prompts: List[str], stop: Optional[List[str]] = None) -> LLMResult:
        generations = []
        # Copy the stop list so the caller's list is never mutated by the append below.
        stop = list(stop) if stop else []
        stop.append('</end>')
        # Debug: log the first prompt to a file.
        with open('llm-file.log', 'w') as f:
            f.write(prompts[0])
        for prompt in prompts:
            prompt_generations = []
            prompt_length = len(prompt)
            for _ in range(self.generation_attempts):
                generated_length = 0
                generated_string = ""
                # generate_reply() yields the prompt plus everything generated so far;
                # slice off the prompt and the previously seen text to get only the new chunk.
                for continuation in generate_reply(prompt,
                                                   self.generate_state,
                                                   stopping_strings=stop):
                    old_generated_length = generated_length
                    generated_length = len(continuation) - prompt_length
                    continuation = continuation[prompt_length + old_generated_length:]
                    generated_string += continuation
                    if self.streaming:
                        self.callback_manager.on_llm_new_token(token=continuation)
                    if any(generated_string.strip().endswith(s) for s in stop):
                        break
                prompt_generations.append(Generation(text=generated_string))
            generations.append(prompt_generations)
        print('G:', generations, flush=True)  # Debug output
        return LLMResult(generations=generations)

    async def _agenerate(self, prompts: List[str], stop: Optional[List[str]] = None) -> LLMResult:
        return self._generate(prompts=prompts, stop=stop)
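

# Illustrative usage sketch (not part of the original module). The state dict
# below is hypothetical: the exact keys expected by generate_reply() depend on
# the text-generation-webui version in use, and the class's own fields are not
# forwarded automatically -- set_state() simply stores whatever dict you pass.
if __name__ == "__main__":
    llm = WebUILLM(streaming=False, generation_attempts=1)
    llm.set_state({
        'max_new_tokens': llm.max_new_tokens,
        'do_sample': llm.do_sample,
        'temperature': llm.temperature,
        'top_p': llm.top_p,
        'top_k': llm.top_k,
        'seed': llm.seed,
    })
    # BaseLLM.generate() returns an LLMResult; generations is a list per prompt.
    result = llm.generate(["Write a haiku about GPUs."], stop=["\n\n"])
    print(result.generations[0][0].text)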