Skip to content

Commit

Permalink
fix message role
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard committed Aug 2, 2023
1 parent 0955829 commit 7d1c74b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { Playground } from 'state/playground';

/**
 * Resolve which LLM provider the playground should use for the current prompt.
 *
 * Chat prompts (those with `messages`) are matched against chat providers,
 * template prompts against completion providers. The provider referenced by
 * the prompt itself is preferred; otherwise the first eligible provider is
 * used as a fallback.
 *
 * @param playground current playground state
 * @returns the resolved provider and whether it was the prompt's own provider
 * @throws Error when no provider matches the prompt's chat/completion mode
 */
export default function getProvider(playground: Playground) {
  // A prompt carrying messages requires a chat-capable provider.
  const isChat = !!playground?.prompt?.messages;

  // Keep only providers matching the prompt's mode; default to an empty list.
  const providers = playground?.providers
    ? playground.providers.filter((p) => p.is_chat === isChat)
    : [];

  // `providers` is always an array here, so only the length check is needed.
  if (!providers.length) {
    throw new Error('No LLM provider available');
  }

  // Prefer the provider the prompt was originally produced with.
  const match = providers.find((p) => p.id === playground.prompt?.provider);

  return {
    // Fall back to the first eligible provider when the prompt's is unknown.
    provider: match ?? providers[0],
    providerFound: !!match
  };
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import ActionBar from './actionBar';
import BasicPromptPlayground from './basic';
import ChatPromptPlayground from './chat';
import VariableModal from './editor/variableModal';
import getProvider from './helpers';
import ModelSettings from './modelSettings';

export type PromptMode = 'Template' | 'Formatted';
Expand Down Expand Up @@ -78,7 +79,9 @@ export default function Playground() {

const submit = async () => {
try {
const { provider } = getProvider(playground);
const prompt = preparePrompt(playground.prompt);
prompt.provider = provider.id;
setLoading(true);
const completion = await client.getCompletion(prompt, userEnv);
setPlayground((old) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
import { playgroundState } from 'state/playground';

import FormInput, { TFormInput, TFormInputValue } from '../FormInput';
import getProvider from './helpers';

type Schema = {
[key: string]: yup.Schema;
Expand All @@ -26,23 +27,7 @@ type Schema = {
const ModelSettings = () => {
const [playground, setPlayground] = useRecoilState(playgroundState);

const isChat = !!playground?.prompt?.messages;

const providers = playground?.providers
? playground.providers.filter((p) => p.is_chat === isChat)
: [];

if (!providers) {
throw new Error('No LLM provider available.');
}

let provider = providers.find(
(provider) => provider.id === playground.prompt?.provider
);

const providerFound = !!provider;

provider = provider || providers[0];
const { provider, providerFound } = getProvider(playground);

const providerWarning = !providerFound ? (
<Alert severity="warning">
Expand Down Expand Up @@ -95,7 +80,6 @@ const ModelSettings = () => {
setPlayground((old) =>
merge(cloneDeep(old), {
prompt: {
provider: provider?.id,
settings: formik.values
}
})
Expand Down
2 changes: 1 addition & 1 deletion src/chainlit/frontend/src/state/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ interface IBaseTemplate {
}

export interface IPromptMessage extends IBaseTemplate {
role: string;
role: 'system' | 'assistant' | 'user' | 'function';
}

export type ILLMSettings = Record<string, string | string[] | number>;
Expand Down
18 changes: 16 additions & 2 deletions src/chainlit/lc/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def build_prompt(serialized: Dict[str, Any], inputs: Dict[str, Any]):
return Prompt(inputs=inputs, messages=messages)


def convert_role(role: str):
    """Translate a LangChain message type into an OpenAI-style prompt role.

    Both "human" and generic "chat" messages map to the "user" role;
    "ai" maps to "assistant"; "system" and "function" pass through.

    Raises:
        ValueError: if the message type has no known role equivalent.
    """
    role_map = {
        "human": "user",
        "chat": "user",
        "system": "system",
        "ai": "assistant",
        "function": "function",
    }
    mapped = role_map.get(role)
    if mapped is None:
        raise ValueError(f"Unsupported role {role}")
    return mapped


class BaseLangchainCallbackHandler(BaseCallbackHandler):
emitter: ChainlitEmitter
# Keep track of the prompt to display them in the prompt playground.
Expand Down Expand Up @@ -210,14 +223,15 @@ def _on_chat_model_start(
if self.current_prompt.messages:
for idx, m in enumerate(messages[0]):
self.current_prompt.messages[idx].formatted = m.content
self.current_prompt.messages[idx].role = m.type
self.current_prompt.messages[idx].role = convert_role(m.type)

elif self.current_prompt.template:
formatted_prompt = "\n".join([m.content for m in messages[0]])
self.current_prompt.formatted = formatted_prompt
else:
prompt_messages = [
PromptMessage(formatted=m.content, role=m.type) for m in messages[0]
PromptMessage(formatted=m.content, role=convert_role(m.type))
for m in messages[0]
]
self.current_prompt = Prompt(
messages=prompt_messages,
Expand Down
2 changes: 1 addition & 1 deletion src/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class BaseTemplate:
@dataclass_json
@dataclass
class PromptMessage(BaseTemplate):
role: Optional[str] = None
role: Optional[Literal["system", "assistant", "user", "function"]] = None


@dataclass_json
Expand Down

0 comments on commit 7d1c74b

Please sign in to comment.