Skip to content

Commit

Permalink
Update functionality of thumbs up/down buttons in UI to update thread…
Browse files Browse the repository at this point in the history
… checkpoint metadata score.
  • Loading branch information
andrewnguonly committed May 15, 2024
1 parent 210f8e3 commit 3eb93ac
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 8 deletions.
27 changes: 27 additions & 0 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class ThreadPostRequest(BaseModel):
config: Optional[Dict[str, Any]] = None


class ThreadPatchRequest(BaseModel):
"""Payload for patching thread state."""

metdata: Dict[str, Any]
config: Optional[Dict[str, Any]] = None


@router.get("/")
async def list_threads(user: AuthedUser) -> List[Thread]:
"""List all threads for the current user."""
Expand Down Expand Up @@ -75,6 +82,26 @@ async def add_thread_state(
)


@router.patch("/{tid}/state")
async def patch_thread_state(
user: AuthedUser,
tid: ThreadID,
payload: ThreadPatchRequest,
):
"""Patch state for a thread."""
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")

return await storage.patch_thread_state(
payload.config or {"configurable": {"thread_id": tid}},
payload.metdata,
)


@router.get("/{tid}/history")
async def get_thread_history(
user: AuthedUser,
Expand Down
9 changes: 8 additions & 1 deletion backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -145,6 +145,13 @@ async def update_thread_state(
return await get_langserve().threads.update_state(config, values)


async def patch_thread_state(
config: RunnableConfig,
metadata: Dict[str, Any],
):
"""Patch state of a thread."""
return await get_langserve().threads.patch_state(config, metadata)

async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assistant):
"""Get the history of a thread."""
return await get_langserve().threads.get_history(thread_id)
Expand Down
1 change: 1 addition & 0 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ export function Chat(props: ChatProps) {
}
startEditing={() => recordEdits(msg)}
alwaysShowControls={i === messages.length - 1}
threadId={currentChat.thread_id}
/>
),
)}
Expand Down
10 changes: 4 additions & 6 deletions frontend/src/components/LangSmithActions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ import {
} from "@heroicons/react/24/outline";
import { useState } from "react";

export function LangSmithActions(props: { runId: string }) {
export function LangSmithActions(props: { runId: string, threadId: string }) {
const [state, setState] = useState<{
score: number;
inflight: boolean;
} | null>(null);
const sendFeedback = async (score: number) => {
setState({ score, inflight: true });
await fetch(`/runs/feedback`, {
method: "POST",
await fetch(`/threads/${props.threadId}/state`, {
method: "PATCH",
body: JSON.stringify({
run_id: props.runId,
key: "user_score",
score: score,
metadata: {score: score},
}),
headers: {
"Content-Type": "application/json",
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/components/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const MessageViewer = memo(function (
runId?: string;
startEditing?: () => void;
alwaysShowControls?: boolean;
threadId: string;
},
) {
const [open, setOpen] = useState(false);
Expand Down Expand Up @@ -116,7 +117,7 @@ export const MessageViewer = memo(function (
</div>
{props.runId && (
<div className="mt-2 pl-[148px]">
<LangSmithActions runId={props.runId} />
<LangSmithActions runId={props.runId} threadId={props.threadId}/>
</div>
)}
</div>
Expand Down

0 comments on commit 3eb93ac

Please sign in to comment.