-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathkernel_function_from_prompt.py
356 lines (313 loc) · 17.1 KB
/
kernel_function_from_prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Copyright (c) Microsoft. All rights reserved.
import logging
import os
from collections.abc import AsyncGenerator
from html import unescape
from typing import TYPE_CHECKING, Any
import yaml
from pydantic import Field, ValidationError, model_validator
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
from semantic_kernel.const import DEFAULT_SERVICE_NAME
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions import FunctionExecutionException, FunctionInitializationError
from semantic_kernel.exceptions.function_exceptions import PromptRenderingException
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.filters.kernel_filters_extension import _rebuild_prompt_render_context
from semantic_kernel.filters.prompts.prompt_render_context import PromptRenderContext
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function import TEMPLATE_FORMAT_MAP, KernelFunction
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.functions.prompt_rendering_result import PromptRenderingResult
from semantic_kernel.prompt_template.const import KERNEL_TEMPLATE_FORMAT_NAME, TEMPLATE_FORMAT_TYPES
from semantic_kernel.prompt_template.prompt_template_base import PromptTemplateBase
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
if TYPE_CHECKING:
    # Only needed for the type annotation in KernelFunctionFromPrompt._render_prompt.
    from semantic_kernel.services.ai_service_client_base import AIServiceClientBase

logger: logging.Logger = logging.getLogger(__name__)

# Names of the two files a prompt-function directory must contain.
CONFIG_FILE_NAME = "config.json"
PROMPT_FILE_NAME = "skprompt.txt"

# Return-parameter metadata shared by every prompt function: a required "return"
# value typed as FunctionResult, with no default.
PROMPT_RETURN_PARAM = KernelParameterMetadata(
    name="return",
    type="FunctionResult",  # type: ignore
    description="The completion result",
    is_required=True,
    default_value=None,
)
class KernelFunctionFromPrompt(KernelFunction):
    """Semantic Kernel Function from a prompt.

    Invoking the function renders its prompt template (through the kernel's prompt-rendering
    filter stack), selects an AI service via the kernel, and sends the rendered prompt to that
    service as either a chat-completion or a text-completion request.
    """

    prompt_template: PromptTemplateBase
    prompt_execution_settings: dict[str, PromptExecutionSettings] = Field(default_factory=dict)

    def __init__(
        self,
        function_name: str,
        plugin_name: str | None = None,
        description: str | None = None,
        prompt: str | None = None,
        template_format: TEMPLATE_FORMAT_TYPES = KERNEL_TEMPLATE_FORMAT_NAME,
        prompt_template: PromptTemplateBase | None = None,
        prompt_template_config: PromptTemplateConfig | None = None,
        prompt_execution_settings: None
        | (PromptExecutionSettings | list[PromptExecutionSettings] | dict[str, PromptExecutionSettings]) = None,
    ) -> None:
        """Initializes a new instance of the KernelFunctionFromPrompt class.

        Args:
            function_name (str): The name of the function.
            plugin_name (str | None): The name of the plugin.
            description (str | None): The description for the function.
            prompt (str | None): The prompt text. Only used to build a PromptTemplateConfig when
                neither prompt_template nor prompt_template_config is supplied.
            template_format (TEMPLATE_FORMAT_TYPES): The template format, defaults to
                KERNEL_TEMPLATE_FORMAT_NAME. Only used when no prompt_template_config is supplied.
            prompt_template (PromptTemplateBase | None): A ready-made prompt template. When given,
                it is used as-is instead of building one from prompt / prompt_template_config.
            prompt_template_config (PromptTemplateConfig | None): The prompt template configuration,
                used to build the template when prompt_template is not supplied.
            prompt_execution_settings: An instance, list or dict of PromptExecutionSettings to be
                used by the function. They can also come from the prompt template's config, but
                the ones supplied here are used if both are present.

        Raises:
            FunctionInitializationError: If no prompt source is supplied at all, or if the
                function metadata fails validation.
        """
        if not prompt and not prompt_template_config and not prompt_template:
            raise FunctionInitializationError(
                "The prompt cannot be empty, must be supplied directly, "
                "through prompt_template_config or in the prompt_template."
            )
        if prompt and prompt_template_config and prompt_template_config.template != prompt:
            logger.warning(
                f"Prompt ({prompt}) and PromptTemplateConfig ({prompt_template_config.template}) both supplied, "
                "using the template in PromptTemplateConfig, ignoring prompt."
            )
        # NOTE(review): template_format has a non-None default, so this also warns when the caller
        # did not pass template_format but the config uses another format; callers such as
        # from_yaml/from_directory pass the config's format explicitly to avoid that.
        if template_format and prompt_template_config and prompt_template_config.template_format != template_format:
            logger.warning(
                f"Template ({template_format}) and PromptTemplateConfig ({prompt_template_config.template_format}) "
                "both supplied, using the template format in PromptTemplateConfig, ignoring template."
            )
        if not prompt_template:
            if not prompt_template_config:
                # prompt must be there if prompt_template and prompt_template_config is not supplied
                prompt_template_config = PromptTemplateConfig(
                    name=function_name,
                    description=description,
                    template=prompt,
                    template_format=template_format,
                )
            prompt_template = TEMPLATE_FORMAT_MAP[prompt_template_config.template_format](
                prompt_template_config=prompt_template_config
            )  # type: ignore

        try:
            metadata = KernelFunctionMetadata(
                name=function_name,
                plugin_name=plugin_name,
                description=description,
                parameters=prompt_template.prompt_template_config.get_kernel_parameter_metadata(),  # type: ignore
                is_prompt=True,
                is_asynchronous=True,
                return_parameter=PROMPT_RETURN_PARAM,
            )
        except ValidationError as exc:
            raise FunctionInitializationError("Failed to create KernelFunctionMetadata") from exc
        super().__init__(
            metadata=metadata,
            prompt_template=prompt_template,  # type: ignore
            prompt_execution_settings=prompt_execution_settings or {},  # type: ignore
        )

    @model_validator(mode="before")
    @classmethod
    def rewrite_execution_settings(
        cls,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Rewrite execution settings to a dictionary keyed by service id.

        If prompt_execution_settings is a single PromptExecutionSettings or a list of them, it is
        converted to a dict keyed by each setting's service_id (DEFAULT_SERVICE_NAME when unset).
        If it is not supplied but prompt_template is, the execution settings of the template's
        config are used instead.
        """
        if not isinstance(data, dict):
            # "before" validators may receive non-dict input; there is nothing to rewrite then.
            return data
        prompt_execution_settings = data.get("prompt_execution_settings")
        prompt_template = data.get("prompt_template")
        if not prompt_execution_settings and prompt_template:
            prompt_execution_settings = prompt_template.prompt_template_config.execution_settings
            data["prompt_execution_settings"] = prompt_execution_settings
        if not prompt_execution_settings:
            return data
        if isinstance(prompt_execution_settings, PromptExecutionSettings):
            data["prompt_execution_settings"] = {
                prompt_execution_settings.service_id or DEFAULT_SERVICE_NAME: prompt_execution_settings
            }
        if isinstance(prompt_execution_settings, list):
            data["prompt_execution_settings"] = {
                s.service_id or DEFAULT_SERVICE_NAME: s for s in prompt_execution_settings
            }
        return data

    async def _invoke_internal(self, context: FunctionInvocationContext) -> None:
        """Invoke the function: render the prompt, then run it against the selected AI service.

        If a prompt-rendering filter supplied a function result, that result is used as-is and no
        service call is made. Chat-completion services receive a ChatHistory built from the
        rendered prompt; text-completion services receive the HTML-unescaped rendered prompt.

        Raises:
            FunctionExecutionException: If the service call fails, or a chat service returns
                no completions.
            ValueError: If the selected service is neither a chat nor a text completion client.
        """
        prompt_render_result = await self._render_prompt(context)
        if prompt_render_result.function_result is not None:
            context.result = prompt_render_result.function_result
            return

        if isinstance(prompt_render_result.ai_service, ChatCompletionClientBase):
            chat_history = ChatHistory.from_rendered_prompt(prompt_render_result.rendered_prompt)
            try:
                chat_message_contents = await prompt_render_result.ai_service.get_chat_message_contents(
                    chat_history=chat_history,
                    settings=prompt_render_result.execution_settings,
                    **{"kernel": context.kernel, "arguments": context.arguments},
                )
            except Exception as exc:
                raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc
            if not chat_message_contents:
                raise FunctionExecutionException(f"No completions returned while invoking function {self.name}")
            context.result = self._create_function_result(
                completions=chat_message_contents, chat_history=chat_history, arguments=context.arguments
            )
            return

        if isinstance(prompt_render_result.ai_service, TextCompletionClientBase):
            try:
                texts = await prompt_render_result.ai_service.get_text_contents(
                    prompt=unescape(prompt_render_result.rendered_prompt),
                    settings=prompt_render_result.execution_settings,
                )
            except Exception as exc:
                raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc
            context.result = self._create_function_result(
                completions=texts, arguments=context.arguments, prompt=prompt_render_result.rendered_prompt
            )
            return

        raise ValueError(f"Service `{type(prompt_render_result.ai_service).__name__}` is not a valid AI service")

    async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
        """Invoke the function in streaming mode.

        context.result is set to a FunctionResult whose value is the async generator returned by
        the service's streaming method (it is not consumed here). As in the non-streaming path, a
        function result supplied by a prompt-rendering filter short-circuits the service call, and
        text-completion services receive the HTML-unescaped rendered prompt.

        Raises:
            FunctionExecutionException: If the selected service is neither a chat nor a text
                completion client.
        """
        prompt_render_result = await self._render_prompt(context)
        if prompt_render_result.function_result is not None:
            context.result = prompt_render_result.function_result
            return

        if isinstance(prompt_render_result.ai_service, ChatCompletionClientBase):
            chat_history = ChatHistory.from_rendered_prompt(prompt_render_result.rendered_prompt)
            value: AsyncGenerator = prompt_render_result.ai_service.get_streaming_chat_message_contents(
                chat_history=chat_history,
                settings=prompt_render_result.execution_settings,
                **{"kernel": context.kernel, "arguments": context.arguments},
            )
        elif isinstance(prompt_render_result.ai_service, TextCompletionClientBase):
            # Unescape to match the non-streaming text path in _invoke_internal.
            value = prompt_render_result.ai_service.get_streaming_text_contents(
                prompt=unescape(prompt_render_result.rendered_prompt),
                settings=prompt_render_result.execution_settings,
            )
        else:
            raise FunctionExecutionException(
                f"Service `{type(prompt_render_result.ai_service)}` is not a valid AI service"
            )

        context.result = FunctionResult(function=self.metadata, value=value)

    async def _render_prompt(self, context: FunctionInvocationContext) -> PromptRenderingResult:
        """Render the prompt and apply the prompt rendering filters.

        Missing arguments are first filled from the template's input-variable defaults. The kernel's
        PROMPT_RENDERING filter stack is then run with _inner_render_prompt as the innermost step,
        and the AI service and execution settings are selected through the kernel.

        Returns:
            PromptRenderingResult: The rendered prompt, the selected service and settings, and any
            function result a filter placed on the render context.

        Raises:
            PromptRenderingException: If no rendered prompt is available after the filters ran.
        """
        self.update_arguments_with_defaults(context.arguments)

        _rebuild_prompt_render_context()
        prompt_render_context = PromptRenderContext(function=self, kernel=context.kernel, arguments=context.arguments)

        stack = context.kernel.construct_call_stack(
            filter_type=FilterTypes.PROMPT_RENDERING,
            inner_function=self._inner_render_prompt,
        )
        await stack(prompt_render_context)

        if prompt_render_context.rendered_prompt is None:
            raise PromptRenderingException("Prompt rendering failed, no rendered prompt was returned.")
        selected_service: tuple["AIServiceClientBase", PromptExecutionSettings] = context.kernel.select_ai_service(
            function=self, arguments=context.arguments
        )
        return PromptRenderingResult(
            rendered_prompt=prompt_render_context.rendered_prompt,
            ai_service=selected_service[0],
            execution_settings=selected_service[1],
            # Forward a filter-supplied result so the invoke methods can short-circuit on it.
            function_result=prompt_render_context.function_result,
        )

    async def _inner_render_prompt(self, context: PromptRenderContext) -> None:
        """Render the prompt using the prompt template and store it on the render context."""
        context.rendered_prompt = await self.prompt_template.render(context.kernel, context.arguments)

    def _create_function_result(
        self,
        completions: list[ChatMessageContent] | list[TextContent],
        arguments: KernelArguments,
        chat_history: ChatHistory | None = None,
        prompt: str | None = None,
    ) -> FunctionResult:
        """Creates a function result with the given completions.

        The result's metadata always holds the arguments and each completion's metadata; the chat
        history ("messages") and the rendered prompt ("prompt") are added when they are truthy.
        """
        metadata: dict[str, Any] = {
            "arguments": arguments,
            "metadata": [completion.metadata for completion in completions],
        }
        if chat_history:
            metadata["messages"] = chat_history
        if prompt:
            metadata["prompt"] = prompt
        return FunctionResult(
            function=self.metadata,
            value=completions,
            metadata=metadata,
        )

    def update_arguments_with_defaults(self, arguments: KernelArguments) -> None:
        """Update any missing values with their defaults.

        Only input variables absent from ``arguments`` are filled. Defaults equal to None, "",
        False or 0 are not applied.
        """
        for parameter in self.prompt_template.prompt_template_config.input_variables:
            # A tuple (not a set) is used so unhashable defaults such as lists or dicts can be
            # compared without raising TypeError.
            if parameter.name not in arguments and parameter.default not in (None, "", False, 0):
                arguments[parameter.name] = parameter.default

    @classmethod
    def from_yaml(cls, yaml_str: str, plugin_name: str | None = None) -> "KernelFunctionFromPrompt":
        """Creates a new instance of the KernelFunctionFromPrompt class from a YAML string.

        Args:
            yaml_str (str): YAML whose top level is a mapping of PromptTemplateConfig fields.
            plugin_name (str | None): The name of the plugin.

        Raises:
            FunctionInitializationError: If the YAML is invalid, is not a mapping, or does not
                validate as a PromptTemplateConfig.
        """
        try:
            data = yaml.safe_load(yaml_str)
        except yaml.YAMLError as exc:  # pragma: no cover
            raise FunctionInitializationError(f"Invalid YAML content: {yaml_str}, error: {exc}") from exc

        if not isinstance(data, dict):
            raise FunctionInitializationError(f"The YAML content must represent a dictionary, got {yaml_str}")

        try:
            prompt_template_config = PromptTemplateConfig(**data)
        except ValidationError as exc:
            raise FunctionInitializationError(
                f"Error initializing PromptTemplateConfig: {exc} from yaml data: {data}"
            ) from exc
        return cls(
            function_name=prompt_template_config.name,
            plugin_name=plugin_name,
            description=prompt_template_config.description,
            prompt_template_config=prompt_template_config,
            template_format=prompt_template_config.template_format,
        )

    @classmethod
    def from_directory(cls, path: str, plugin_name: str | None = None) -> "KernelFunctionFromPrompt":
        """Creates a new instance of the KernelFunctionFromPrompt class from a directory.

        The directory needs to contain:
        - A prompt file named `skprompt.txt`
        - A config file named `config.json`

        The function is named after the directory. Both files are read as UTF-8.

        Args:
            path (str): Path to the directory; a trailing path separator is allowed.
            plugin_name (str | None): The name of the plugin.

        Returns:
            KernelFunctionFromPrompt: The kernel function from prompt

        Raises:
            FunctionInitializationError: If either required file is missing.
        """
        prompt_path = os.path.join(path, PROMPT_FILE_NAME)
        config_path = os.path.join(path, CONFIG_FILE_NAME)
        prompt_exists = os.path.exists(prompt_path)
        config_exists = os.path.exists(config_path)
        if not config_exists and not prompt_exists:
            raise FunctionInitializationError(
                f"{PROMPT_FILE_NAME} and {CONFIG_FILE_NAME} files are required to create a "
                f"function from a directory, path: {path!s}."
            )
        if not config_exists:
            raise FunctionInitializationError(
                f"{CONFIG_FILE_NAME} files are required to create a function from a directory, "
                f"path: {path!s}, prompt file is there."
            )
        if not prompt_exists:
            raise FunctionInitializationError(
                f"{PROMPT_FILE_NAME} files are required to create a function from a directory, "
                f"path: {path!s}, config file is there."
            )

        # normpath strips a trailing separator; otherwise basename("dir/") would be "".
        function_name = os.path.basename(os.path.normpath(path))

        # Read explicitly as UTF-8 instead of the platform's locale encoding.
        with open(config_path, encoding="utf-8") as config_file:
            prompt_template_config = PromptTemplateConfig.from_json(config_file.read())
        prompt_template_config.name = function_name

        with open(prompt_path, encoding="utf-8") as prompt_file:
            prompt_template_config.template = prompt_file.read()

        prompt_template = TEMPLATE_FORMAT_MAP[prompt_template_config.template_format](  # type: ignore
            prompt_template_config=prompt_template_config
        )
        return cls(
            function_name=function_name,
            plugin_name=plugin_name,
            prompt_template=prompt_template,
            prompt_template_config=prompt_template_config,
            template_format=prompt_template_config.template_format,
            description=prompt_template_config.description,
        )