
feat(traces): add reranking span kind for document reranking in llama index #1588

Merged: 13 commits, Oct 12, 2023
1 change: 1 addition & 0 deletions app/schema.graphql
@@ -576,6 +576,7 @@ enum SpanKind {
retriever
embedding
agent
reranking
unknown
}

9 changes: 9 additions & 0 deletions app/src/openInference/tracing/semanticConventions.ts
@@ -1,6 +1,7 @@
export const SemanticAttributePrefixes = {
llm: "llm",
retrieval: "retrieval",
reranking: "reranking",
messages: "messages",
message: "message",
document: "document",
@@ -26,6 +27,14 @@ export const RetrievalAttributePostfixes = {
documents: "documents",
} as const;

export const RerankingAttributePostfixes = {
input_documents: "input_documents",
output_documents: "output_documents",
query: "query",
model_name: "model_name",
top_k: "top_k",
} as const;

export const EmbeddingAttributePostfixes = {
embeddings: "embeddings",
text: "text",
99 changes: 99 additions & 0 deletions app/src/pages/trace/TracePage.tsx
@@ -52,6 +52,7 @@ import {
MESSAGE_FUNCTION_CALL_NAME,
MESSAGE_NAME,
MESSAGE_ROLE,
RerankingAttributePostfixes,
RetrievalAttributePostfixes,
SemanticAttributePrefixes,
ToolAttributePostfixes,
@@ -324,6 +325,12 @@ function SpanInfo({ span }: { span: Span }) {
);
break;
}
case "reranking": {
content = (
<RerankingSpanInfo span={span} spanAttributes={attributesObject} />
);
break;
}
case "embedding": {
content = (
<EmbeddingSpanInfo span={span} spanAttributes={attributesObject} />
@@ -572,6 +579,98 @@ function RetrieverSpanInfo(props: {
);
}

function RerankingSpanInfo(props: {
span: Span;
spanAttributes: AttributeObject;
}) {
const { spanAttributes } = props;
const rerankingAttributes = useMemo<AttributeObject | null>(() => {
const rerankingAttrs = spanAttributes[SemanticAttributePrefixes.reranking];
if (typeof rerankingAttrs === "object") {
return rerankingAttrs as AttributeObject;
}
return null;
}, [spanAttributes]);
const query = useMemo<string>(() => {
if (rerankingAttributes == null) {
return "";
}
return (rerankingAttributes[RerankingAttributePostfixes.query] ||
"") as string;
}, [rerankingAttributes]);
const input_documents = useMemo<AttributeDocument[]>(() => {
if (rerankingAttributes == null) {
return [];
}
return (rerankingAttributes[RerankingAttributePostfixes.input_documents] ||
[]) as AttributeDocument[];
}, [rerankingAttributes]);
const output_documents = useMemo<AttributeDocument[]>(() => {
if (rerankingAttributes == null) {
return [];
}
return (rerankingAttributes[RerankingAttributePostfixes.output_documents] ||
[]) as AttributeDocument[];
}, [rerankingAttributes]);

const numInputDocuments = input_documents.length;
const numOutputDocuments = output_documents.length;
return (
<Flex direction="column" gap="size-200">
<Card title="Query" {...defaultCardProps}>
<CodeBlock value={query} mimeType="text" />
</Card>
<Card
title={`Input Documents (${numInputDocuments})`}
Contributor: There is a titleExtra prop on Card where you can place a Counter component: https://5f9739a76e154c00220dd4b9-zeknbennzf.chromatic.com/?path=/story/counter--gallery

{...defaultCardProps}
defaultOpen={false}
>
{
<ul
css={css`
padding: var(--ac-global-dimension-static-size-200);
display: flex;
flex-direction: column;
gap: var(--ac-global-dimension-static-size-200);
`}
>
{input_documents.map((document, idx) => {
return (
<li key={idx}>
<DocumentItem document={document} />
Contributor: I might re-color these just so that there's a visual hierarchy of color (e.g., re-ranked documents take on a different tint). That way, as you click around, you can clearly see the difference.

Contributor: Re-using the DocumentItem component is good, but I think just showing the new score label might be a tad confusing? Or is using score as an abstraction intended here?

Contributor: In one case it's often spatial distance, whereas when run through a reranker it's a relevance rank. From a user's perspective, displaying "score: XX" alongside both loses a bit of an opportunity to explain the score in this context, score being pretty generic.

Contributor (Author): I think even though score is generic, it is still accurate. On the input side of the reranker, a score may or may not exist, and even if it does, it is not considered by the reranker. If the "input" score does exist, it was generated by a preprocessor for a separate purpose. The general mental picture: there could be millions of documents in a corpus, only a relatively small set is chosen to be reranked, and that selection process can have a score of its own based on the query in question. Even though that score is not meaningful to the reranker, it is still an informative attribute of the input document, because it relays how the document became a candidate in the first place (especially when the preprocessor is missing from the trace). On the other hand, we can't really get more specific than the "score" verbiage, because we don't have more information. On balance, although it may seem confusing at first, a user should have enough context to reason their way through it. (A sketch after this file's diff illustrates the two senses of "score".)

Contributor: I wasn't disputing the way we capture the score; I was just thinking of ways to reduce the need to "reason their way through it". But I don't have an immediately good prefix for the reranker score, so let's keep it for now.

</li>
);
})}
</ul>
}
</Card>
<Card
title={`Re-ranked Documents (${numOutputDocuments})`}
{...defaultCardProps}
Contributor: Same as above: rely on titleExtra and counter.

Contributor (Author): I added it, and it looks like this. [Screenshot: 2023-10-12 at 8:59 AM]

>
{
<ul
css={css`
padding: var(--ac-global-dimension-static-size-200);
display: flex;
flex-direction: column;
gap: var(--ac-global-dimension-static-size-200);
`}
>
{output_documents.map((document, idx) => {
return (
<li key={idx}>
<DocumentItem document={document} />
</li>
);
})}
</ul>
}
</Card>
</Flex>
);
}

function EmbeddingSpanInfo(props: {
span: Span;
spanAttributes: AttributeObject;
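To ground the "score" discussion in the thread above: a minimal sketch of how the same score field carries two different meanings across a reranking span. The dotted keys and values are illustrative, not Phoenix's exact payloads.

# Input scores (when present) come from the upstream retriever that chose
# the candidates; output scores come from the reranker itself.
input_documents = [
    {"document.id": "doc-a", "document.score": 0.31},  # retriever similarity
    {"document.id": "doc-b", "document.score": 0.27},  # retriever similarity
]
output_documents = [
    {"document.id": "doc-b", "document.score": 0.93},  # reranker relevance
]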
4 changes: 2 additions & 2 deletions app/src/pages/trace/__generated__/TracePageQuery.graphql.ts

Some generated files are not rendered by default.

40 changes: 38 additions & 2 deletions integration-tests/trace/llama_index/test_callback.py
@@ -2,16 +2,21 @@

import pytest
from gcsfs import GCSFileSystem
-from llama_index import ServiceContext, StorageContext, load_index_from_storage
+from llama_index import (
+    ServiceContext,
+    StorageContext,
+    load_index_from_storage,
+)
from llama_index.agent import OpenAIAgent
from llama_index.callbacks import CallbackManager
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.graph_stores.simple import SimpleGraphStore
from llama_index.indices.postprocessor.cohere_rerank import CohereRerank
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.tools import FunctionTool
-from phoenix.trace.exporter import NoOpExporter
+from phoenix.trace.exporter import HttpExporter, NoOpExporter
from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler
from phoenix.trace.schemas import SpanKind
from phoenix.trace.semantic_conventions import (
@@ -33,6 +38,10 @@
MESSAGE_ROLE,
OUTPUT_MIME_TYPE,
OUTPUT_VALUE,
RERANKING_INPUT_DOCUMENTS,
RERANKING_MODEL_NAME,
RERANKING_OUTPUT_DOCUMENTS,
RERANKING_TOP_K,
Contributor: Quick question on this parameter (https://docs.cohere.com/docs/reranking): I'm guessing this is the same as TOP_N? If you feed, say, 5 documents but pass a top_k of 3, does it only rank 3? Just trying to understand why this is a parameter to the rerank model.

Contributor (Author): Yes, this is the same as TOP_N. (The letter is K in the literature because N is usually the total number of docs.) The caller of the reranker usually just wants a relatively small number of docs out of an initial set of tens or hundreds. It's certainly optional, since the reranker can simply rank every document, but in general a reduction in number is expected at each stage of the retrieval process.

Contributor: Still confused, though: if I pass, say, 5 documents with a top_k of 3, does it rank 5 and trim the last two?

RogerHYang (Author), Oct 12, 2023: Yes, it retains up to K in the output, so in the case of top 3, the two docs with the lowest scores are dropped. Top K is applied after ranking all 5 docs. (A sketch after this file's diff illustrates this.)
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
@@ -241,3 +250,30 @@ def add(a: int, b: int) -> int:
"title": "multiply",
"type": "object",
}


@pytest.mark.parametrize("model_name", ["text-davinci-003"], indirect=True)
def test_cohere_rerank(index: VectorStoreIndex) -> None:
callback_handler = OpenInferenceTraceCallbackHandler(exporter=HttpExporter())
service_context = ServiceContext.from_defaults(
callback_manager=CallbackManager(handlers=[callback_handler])
)
cohere_rerank = CohereRerank(top_n=2)
query_engine = index.as_query_engine(
similarity_top_k=5,
node_postprocessors=[cohere_rerank],
service_context=service_context,
)
query_engine.query("How should timestamps be formatted?")

spans = {span.name: span for span in callback_handler.get_spans()}
assert "reranking" in spans
reranking_span = spans["reranking"]
assert reranking_span.span_kind == SpanKind.RERANKING
assert (
len(reranking_span.attributes[RERANKING_INPUT_DOCUMENTS])
== query_engine.retriever.similarity_top_k
)
assert len(reranking_span.attributes[RERANKING_OUTPUT_DOCUMENTS]) == cohere_rerank.top_n
assert reranking_span.attributes[RERANKING_TOP_K] == cohere_rerank.top_n
assert reranking_span.attributes[RERANKING_MODEL_NAME] == cohere_rerank.model
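To make the top_k thread above concrete, here is a minimal sketch of the semantics described there: every candidate is scored first, and the list is trimmed to K afterwards. The relevance function is a crude stand-in, not Cohere's model.

from typing import List, Tuple


def rerank(query: str, documents: List[str], top_k: int) -> List[Tuple[str, float]]:
    """Illustrative reranker: score all candidates, then keep the top K."""

    def relevance(doc: str) -> float:
        # Stand-in scoring: fraction of query words that appear in the doc.
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        return len(query_words & doc_words) / (len(query_words) or 1)

    scored = [(doc, relevance(doc)) for doc in documents]  # all candidates get ranked...
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]  # ...then the lowest-scoring docs are dropped


# With 5 input documents and top_k=3, all 5 are scored and 3 survive.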
1 change: 1 addition & 0 deletions src/phoenix/server/api/types/Span.py
@@ -46,6 +46,7 @@ class SpanKind(Enum):
retriever = trace_schema.SpanKind.RETRIEVER
embedding = trace_schema.SpanKind.EMBEDDING
agent = trace_schema.SpanKind.AGENT
reranking = trace_schema.SpanKind.RERANKING
unknown = trace_schema.SpanKind.UNKNOWN

@classmethod
36 changes: 28 additions & 8 deletions src/phoenix/trace/llama_index/callback.py
@@ -51,6 +51,11 @@
MESSAGE_ROLE,
OUTPUT_MIME_TYPE,
OUTPUT_VALUE,
RERANKING_INPUT_DOCUMENTS,
RERANKING_MODEL_NAME,
RERANKING_OUTPUT_DOCUMENTS,
RERANKING_QUERY,
RERANKING_TOP_K,
RETRIEVAL_DOCUMENTS,
TOOL_DESCRIPTION,
TOOL_NAME,
@@ -82,6 +87,7 @@ class CBEventData(TypedDict, total=False):
def payload_to_semantic_attributes(
event_type: CBEventType,
payload: Dict[str, Any],
is_event_end: bool = False,
) -> Dict[str, Any]:
"""
Converts a LLMapp payload to a dictionary of semantic conventions compliant attributes.
@@ -95,10 +101,10 @@ def payload_to_semantic_attributes(
{EMBEDDING_TEXT: text, EMBEDDING_VECTOR: vector}
for text, vector in zip(payload[EventPayload.CHUNKS], payload[EventPayload.EMBEDDINGS])
]
-    if EventPayload.QUERY_STR in payload:
+    if event_type is not CBEventType.RERANKING and EventPayload.QUERY_STR in payload:
attributes[INPUT_VALUE] = payload[EventPayload.QUERY_STR]
attributes[INPUT_MIME_TYPE] = MimeType.TEXT
-    if EventPayload.NODES in payload:
+    if event_type is not CBEventType.RERANKING and EventPayload.NODES in payload:
attributes[RETRIEVAL_DOCUMENTS] = [
{
DOCUMENT_ID: node_with_score.node.node_id,
@@ -128,11 +134,24 @@
if (usage := getattr(raw, "usage", None)) is not None:
attributes.update(_get_token_counts(usage))
    if event_type is CBEventType.RERANKING:
-        ...  # TODO
-        # if EventPayload.TOP_K in payload:
-        #     attributes[RERANKING_TOP_K] = payload[EventPayload.TOP_K]
-        # if EventPayload.MODEL_NAME in payload:
-        #     attributes[RERANKING_MODEL_NAME] = payload[EventPayload.MODEL_NAME]
+        if EventPayload.TOP_K in payload:
+            attributes[RERANKING_TOP_K] = payload[EventPayload.TOP_K]
+        if EventPayload.MODEL_NAME in payload:
+            attributes[RERANKING_MODEL_NAME] = payload[EventPayload.MODEL_NAME]
+        if EventPayload.QUERY_STR in payload:
+            attributes[RERANKING_QUERY] = payload[EventPayload.QUERY_STR]
+        if nodes := payload.get(EventPayload.NODES):
+            attributes[
+                RERANKING_OUTPUT_DOCUMENTS if is_event_end else RERANKING_INPUT_DOCUMENTS
+            ] = [
+                {
+                    DOCUMENT_ID: node_with_score.node.node_id,
+                    DOCUMENT_SCORE: node_with_score.score,
+                    DOCUMENT_CONTENT: node_with_score.node.text,
+                    DOCUMENT_METADATA: node_with_score.node.metadata,
+                }
+                for node_with_score in nodes
+            ]
if EventPayload.TOOL in payload:
tool_metadata = cast(ToolMetadata, payload.get(EventPayload.TOOL))
attributes[TOOL_NAME] = tool_metadata.name
@@ -221,7 +240,7 @@ def on_event_end(
# Parse the payload to extract the parameters
if payload is not None:
event_data["attributes"].update(
-                payload_to_semantic_attributes(event_type, payload),
+                payload_to_semantic_attributes(event_type, payload, is_event_end=True),
)

def start_trace(self, trace_id: Optional[str] = None) -> None:
@@ -351,6 +370,7 @@ def _get_span_kind(event_type: CBEventType) -> SpanKind:
CBEventType.RETRIEVE: SpanKind.RETRIEVER,
CBEventType.FUNCTION_CALL: SpanKind.TOOL,
CBEventType.AGENT_STEP: SpanKind.AGENT,
CBEventType.RERANKING: SpanKind.RERANKING,
}.get(event_type, SpanKind.CHAIN)


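A note on the is_event_end flag added above: LlamaIndex attaches a NODES payload to both ends of a RERANKING event, so the handler routes the same payload shape to different attributes depending on the flag. A minimal sketch of the two calls (start_payload and end_payload are hypothetical stand-ins for the dicts LlamaIndex passes to on_event_start and on_event_end):

# At event start, NODES holds the reranker's input candidates:
attributes = payload_to_semantic_attributes(CBEventType.RERANKING, start_payload)
# -> RERANKING_INPUT_DOCUMENTS (plus query/top_k/model_name if present)

# At event end, NODES holds the reranked results:
attributes = payload_to_semantic_attributes(
    CBEventType.RERANKING, end_payload, is_event_end=True
)
# -> RERANKING_OUTPUT_DOCUMENTS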
1 change: 1 addition & 0 deletions src/phoenix/trace/schemas.py
@@ -41,6 +41,7 @@ class SpanKind(Enum):
RETRIEVER = "RETRIEVER"
EMBEDDING = "EMBEDDING"
AGENT = "AGENT"
RERANKING = "RERANKING"
UNKNOWN = "UNKNOWN"

def __str__(self) -> str:
21 changes: 21 additions & 0 deletions src/phoenix/trace/semantic_conventions.py
@@ -179,3 +179,24 @@ def _missing_(cls, v: Any) -> Optional["MimeType"]:
"""
Document metadata as a string representing a JSON object
"""

RERANKING_INPUT_DOCUMENTS = "reranking.input_documents"
"""
List of documents as input to the reranker
"""
RERANKING_OUTPUT_DOCUMENTS = "reranking.output_documents"
"""
List of documents as output from the reranker
"""
RERANKING_QUERY = "reranking.query"
"""
Query string for the reranker
"""
RERANKING_MODEL_NAME = "reranking.model_name"
"""
Model name of the reranker
"""
RERANKING_TOP_K = "reranking.top_k"
"""
Top K parameter of the reranker
"""
7 changes: 7 additions & 0 deletions tests/server/api/types/test_span.py
@@ -60,6 +60,13 @@ def test_nested_attributes() -> None:
"retrieval": {
"documents": ...,
},
"reranking": {
"input_documents": ...,
"output_documents": ...,
"model_name": ...,
"top_k": ...,
"query": ...,
},
"tool": {
"description": ...,
"name": ...,