Skip to content

Commit

Permalink
feat: chore: show how to log context in RAG notebook example
Browse files Browse the repository at this point in the history
  • Loading branch information
Stainless Bot committed Sep 25, 2024
1 parent 0e7f228 commit 5610593
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions examples/tracing/rag/rag_tracing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"\n",
"# OpenAI env variables\n",
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR_OPENAI_API_KEY_HERE\"\n",
Expand Down Expand Up @@ -58,13 +57,12 @@
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import time\n",
"from typing import List\n",
"\n",
"import numpy as np\n",
"from openai import OpenAI\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"from openlayer.lib import trace, trace_openai"
]
Expand Down Expand Up @@ -93,31 +91,35 @@
"\n",
" Answers to a user query with the LLM.\n",
" \"\"\"\n",
" context = self.retrieve_context(user_query)\n",
" context = self.retrieve_contexts(user_query)\n",
" prompt = self.inject_prompt(user_query, context)\n",
" answer = self.generate_answer_with_gpt(prompt)\n",
" return answer\n",
"\n",
" @trace()\n",
" def retrieve_context(self, query: str) -> str:\n",
" def retrieve_contexts(self, query: str) -> List[str]:\n",
" \"\"\"Context retriever.\n",
"\n",
" Given the query, returns the most similar context (using TFIDF).\n",
" \"\"\"\n",
" query_vector = self.vectorizer.transform([query])\n",
" cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()\n",
" most_relevant_idx = np.argmax(cosine_similarities)\n",
" return self.context_sections[most_relevant_idx]\n",
" contexts = [self.context_sections[most_relevant_idx]]\n",
" return contexts\n",
"\n",
" @trace()\n",
" def inject_prompt(self, query: str, context: str):\n",
" # You can also specify the name of the `context_kwarg` to unlock RAG metrics that\n",
" # evaluate the performance of the context retriever. The value of the `context_kwarg`\n",
" # should be a list of strings.\n",
" @trace(context_kwarg=\"contexts\")\n",
" def inject_prompt(self, query: str, contexts: List[str]) -> List[dict]:\n",
" \"\"\"Combines the query with the context and returns\n",
" the prompt (formatted to conform with OpenAI models).\"\"\"\n",
" return [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"Answer the user query using only the following context: {context}. \\nUser query: {query}\",\n",
" \"content\": f\"Answer the user query using only the following context: {contexts[0]}. \\nUser query: {query}\",\n",
" },\n",
" ]\n",
"\n",
Expand Down Expand Up @@ -172,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f960a36f-3438-4c81-8cdb-ca078aa509cd",
"id": "a45d5562",
"metadata": {},
"outputs": [],
"source": []
Expand Down

0 comments on commit 5610593

Please sign in to comment.