actions.py
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import (
    get_multiline_response,
    llm_call,
    strip_quotes,
)
from nemoguardrails.context import llm_call_info_var
from nemoguardrails.llm.params import llm_params
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.callbacks import logging_callback_manager_for_chain
from nemoguardrails.logging.explain import LLMCallInfo

log = logging.getLogger(__name__)

HALLUCINATION_NUM_EXTRA_RESPONSES = 2


@action()
async def self_check_hallucination(
    llm: BaseLLM,
    llm_task_manager: LLMTaskManager,
    context: Optional[dict] = None,
    use_llm_checking: bool = True,
    config: Optional[RailsConfig] = None,
):
    """Checks if the last bot response is a hallucination by checking multiple completions for self-consistency.

    :return: True if hallucination is detected, False otherwise.
    """
    try:
        from langchain_openai import OpenAI
    except ImportError:
        log.warning(
            "The langchain_openai module is not installed. Please install it using pip: pip install langchain_openai"
        )

    bot_response = context.get("bot_message")
    last_bot_prompt_string = context.get("_last_bot_prompt")

    if bot_response and last_bot_prompt_string:
        num_responses = HALLUCINATION_NUM_EXTRA_RESPONSES
        # Request several completions with a single LLM call (via the sampling parameter `n`).
        # At the moment, only OpenAI LLM engines are supported for computing the additional completions.
        if "openai" not in str(type(llm)).lower():
            log.warning(
                f"Hallucination rail is optimized for OpenAI LLM engines. "
                f"Current LLM engine is {type(llm).__name__}, which may not support all features."
            )

        if "n" not in llm.__fields__:
            log.warning(
                f"LLM engine {type(llm).__name__} does not support the 'n' parameter for generating multiple completion choices. "
                f"Please use an OpenAI LLM engine or a model that supports the 'n' parameter for optimal performance."
            )
            return False

        # Use the "generate" call from langchain to get all completions in the same response.
        last_bot_prompt = PromptTemplate(template="{text}", input_variables=["text"])
        chain = LLMChain(prompt=last_bot_prompt, llm=llm)

        # Generate multiple responses with temperature 1.
        with llm_params(llm, temperature=1.0, n=num_responses):
            extra_llm_response = await chain.agenerate(
                [{"text": last_bot_prompt_string}],
                run_manager=logging_callback_manager_for_chain,
            )

        extra_llm_completions = []
        if len(extra_llm_response.generations) > 0:
            extra_llm_completions = extra_llm_response.generations[0]

        extra_responses = []
        i = 0
        while i < num_responses and i < len(extra_llm_completions):
            result = extra_llm_completions[i].text
            # We need the same post-processing of responses as in "generate_bot_message"
            result = get_multiline_response(result)
            result = strip_quotes(result)
            extra_responses.append(result)
            i += 1

        if len(extra_responses) == 0:
            # Log message and return that no hallucination was found
            log.warning(
                f"No extra LLM responses were generated for '{bot_response}' hallucination check."
            )
            return False
        elif len(extra_responses) < num_responses:
            log.warning(
                f"Requested {num_responses} extra LLM responses for hallucination check, "
                f"received {len(extra_responses)}."
            )

        if use_llm_checking:
            # Only support LLM-based agreement check in current version
            prompt = llm_task_manager.render_task_prompt(
                task=Task.SELF_CHECK_HALLUCINATION,
                context={
                    "statement": bot_response,
                    "paragraph": ". ".join(extra_responses),
                },
            )

            # Initialize the LLMCallInfo object
            llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_HALLUCINATION.value))

            stop = llm_task_manager.get_stop_tokens(task=Task.SELF_CHECK_HALLUCINATION)

            with llm_params(llm, temperature=config.lowest_temperature):
                agreement = await llm_call(llm, prompt, stop=stop)

            agreement = agreement.lower().strip()
            log.info(f"Agreement result for looking for hallucination is {agreement}.")

            # Return True if the hallucination check fails
            return "no" in agreement
        else:
            # TODO Implement BERT-Score based consistency method proposed by SelfCheckGPT paper
            # See details: https://arxiv.org/abs/2303.08896
            return False

    return False
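

# -----------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal, hedged sketch
# of calling this action directly, e.g. from a test script. In production the
# action is normally triggered by the hallucination output rail configured in
# config.yml rather than called by hand. The sketch assumes a guardrails config
# folder (here "./config", a placeholder) that defines the
# `self_check_hallucination` prompt, an OpenAI-backed LLM engine, and that
# `LLMTaskManager` can be constructed directly from the loaded `RailsConfig`.
#
#   import asyncio
#
#   from langchain_openai import OpenAI
#
#   from nemoguardrails import RailsConfig
#   from nemoguardrails.llm.taskmanager import LLMTaskManager
#
#   # Load the guardrails configuration and build the helpers the action expects.
#   rails_config = RailsConfig.from_path("./config")
#   llm = OpenAI(temperature=0.0)
#   task_manager = LLMTaskManager(rails_config)
#
#   # The action reads the last bot message and the prompt that produced it
#   # from the context dict; these values are placeholders.
#   context = {
#       "bot_message": "The Eiffel Tower is located in Berlin.",
#       "_last_bot_prompt": "Answer the question: Where is the Eiffel Tower?",
#   }
#
#   is_hallucination = asyncio.run(
#       self_check_hallucination(
#           llm=llm,
#           llm_task_manager=task_manager,
#           context=context,
#           config=rails_config,
#       )
#   )
#   print(f"Hallucination detected: {is_hallucination}")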