feat(traces): add reranking span kind for document reranking in llama index #1588

Merged: 13 commits, Oct 12, 2023
1 change: 1 addition & 0 deletions app/schema.graphql
@@ -576,6 +576,7 @@ enum SpanKind {
retriever
embedding
agent
reranker
unknown
}

25 changes: 25 additions & 0 deletions app/src/components/trace/SpanKindIcon.tsx
@@ -164,6 +164,27 @@ const RetrieverSVG = () => (
</svg>
);

const RerankerSVG = () => (
<svg
width="20"
height="20"
viewBox="0 0 20 20"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<rect
x="0.5"
y="0.5"
width="19"
height="19"
rx="3.5"
stroke="currentColor"
/>
<path d="M4.5359 10L8 4L11.4641 10H4.5359Z" stroke="currentColor" />
<path d="M8.5359 10L12 16L15.4641 10H8.5359Z" stroke="currentColor" />
</svg>
);

const ChainSVG = () => (
<svg
width="20"
@@ -219,6 +240,10 @@ export function SpanKindIcon({ spanKind }: { spanKind: string }) {
color = "--ac-global-color-yellow-1200";
icon = <ToolSVG />;
break;
case "reranker":
color = "--ac-global-color-celery-1000";
icon = <RerankerSVG />;
break;
}

return (
3 changes: 3 additions & 0 deletions app/src/components/trace/SpanKindLabel.tsx
@@ -16,6 +16,9 @@ export function SpanKindLabel(props: { spanKind: string }) {
case "retriever":
color = "seafoam-1000";
break;
case "reranker":
color = "celery-1000";
break;
case "embedding":
color = "indigo-1000";
break;
9 changes: 9 additions & 0 deletions app/src/openInference/tracing/semanticConventions.ts
@@ -1,6 +1,7 @@
export const SemanticAttributePrefixes = {
llm: "llm",
retrieval: "retrieval",
reranker: "reranker",
messages: "messages",
message: "message",
document: "document",
@@ -26,6 +27,14 @@ export const RetrievalAttributePostfixes = {
documents: "documents",
} as const;

export const RerankerAttributePostfixes = {
input_documents: "input_documents",
output_documents: "output_documents",
query: "query",
model_name: "model_name",
top_k: "top_k",
} as const;

export const EmbeddingAttributePostfixes = {
embeddings: "embeddings",
text: "text",
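These postfixes mirror the flat RERANKER_* constants imported on the Python side in the test and callback changes below. As a rough sketch only, assuming the keys are composed by dot-joining prefix and postfix (a convention this diff does not itself show), the flat attribute names would come out roughly as follows; the literal strings are assumptions, not values taken from this PR.

# Illustrative sketch only: the dot-joined composition and the literal strings
# are assumptions; the actual RERANKER_* values live elsewhere in phoenix and
# are not part of this diff.
RERANKER_PREFIX = "reranker"

RERANKER_INPUT_DOCUMENTS = f"{RERANKER_PREFIX}.input_documents"
RERANKER_OUTPUT_DOCUMENTS = f"{RERANKER_PREFIX}.output_documents"
RERANKER_QUERY = f"{RERANKER_PREFIX}.query"
RERANKER_MODEL_NAME = f"{RERANKER_PREFIX}.model_name"
RERANKER_TOP_K = f"{RERANKER_PREFIX}.top_k"
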
101 changes: 101 additions & 0 deletions app/src/pages/trace/TracePage.tsx
@@ -52,6 +52,7 @@ import {
MESSAGE_FUNCTION_CALL_NAME,
MESSAGE_NAME,
MESSAGE_ROLE,
RerankerAttributePostfixes,
RetrievalAttributePostfixes,
SemanticAttributePrefixes,
ToolAttributePostfixes,
@@ -324,6 +325,12 @@ function SpanInfo({ span }: { span: Span }) {
);
break;
}
case "reranker": {
content = (
<RerankerSpanInfo span={span} spanAttributes={attributesObject} />
);
break;
}
case "embedding": {
content = (
<EmbeddingSpanInfo span={span} spanAttributes={attributesObject} />
@@ -572,6 +579,100 @@ function RetrieverSpanInfo(props: {
);
}

function RerankerSpanInfo(props: {
span: Span;
spanAttributes: AttributeObject;
}) {
const { spanAttributes } = props;
const rerankerAttributes = useMemo<AttributeObject | null>(() => {
const rerankerAttrs = spanAttributes[SemanticAttributePrefixes.reranker];
if (typeof rerankerAttrs === "object") {
return rerankerAttrs as AttributeObject;
}
return null;
}, [spanAttributes]);
const query = useMemo<string>(() => {
if (rerankerAttributes == null) {
return "";
}
return (rerankerAttributes[RerankerAttributePostfixes.query] ||
"") as string;
}, [rerankerAttributes]);
const input_documents = useMemo<AttributeDocument[]>(() => {
if (rerankerAttributes == null) {
return [];
}
return (rerankerAttributes[RerankerAttributePostfixes.input_documents] ||
[]) as AttributeDocument[];
}, [rerankerAttributes]);
const output_documents = useMemo<AttributeDocument[]>(() => {
if (rerankerAttributes == null) {
return [];
}
return (rerankerAttributes[RerankerAttributePostfixes.output_documents] ||
[]) as AttributeDocument[];
}, [rerankerAttributes]);

const numInputDocuments = input_documents.length;
const numOutputDocuments = output_documents.length;
return (
<Flex direction="column" gap="size-200">
<Card title="Query" {...defaultCardProps}>
<CodeBlock value={query} mimeType="text" />
</Card>
<Card
title={"Input Documents"}
titleExtra={<Counter variant="light">{numInputDocuments}</Counter>}
{...defaultCardProps}
defaultOpen={false}
>
{
<ul
css={css`
padding: var(--ac-global-dimension-static-size-200);
display: flex;
flex-direction: column;
gap: var(--ac-global-dimension-static-size-200);
`}
>
{input_documents.map((document, idx) => {
return (
<li key={idx}>
<DocumentItem document={document} />
Review comments (on DocumentItem):

Contributor: I might re-color these just so that there's a visual hierarchy of color (e.g. so that re-ranked documents take on a different tint). That way, as you are clicking around, you can clearly see the difference.

Contributor: Re-using the DocumentItem component is good, but I think just showing the new score label might be a tad confusing? Or is using score as an abstraction intended here?

Contributor: In one case it's often spatial distance, whereas when running through a reranker it is a relevance rank. Just thinking that, from a user's perspective, displaying score: XX alongside both loses a bit of an opportunity to explain the score in this context, score being pretty generic.

Contributor (Author): I think even though score is generic, it is still accurate. On the input side of the reranker, the score may or may not exist, and even if it does exist, it's not considered by the reranker. But if the "input" score does exist, it was generated by a preprocessor for a separate purpose. The general mental picture here is that there could be millions of documents in a corpus, and only a relatively small set are chosen to be reranked; that selection process can have a score of its own based on the query in question. Even though that score is not meaningful to the reranker, it is still an informative attribute of the input document, because it relays the reason the document became a candidate in the first place (especially when the preprocessor is missing from the trace). On the other hand, we can't really get more specific than the score verbiage because we don't have more information. On balance, although it may seem confusing at first, a user should have enough context to reason their way through it.

Contributor: I wasn't disputing the way we capture the score; I was just thinking of ways to reduce the "reason their way through it" step a bit. But I don't have an immediately better prefix for the reranker score, so let's keep it for now.

</li>
);
})}
</ul>
}
</Card>
<Card
title={"Output Documents"}
titleExtra={<Counter variant="light">{numOutputDocuments}</Counter>}
{...defaultCardProps}
>
{
<ul
css={css`
padding: var(--ac-global-dimension-static-size-200);
display: flex;
flex-direction: column;
gap: var(--ac-global-dimension-static-size-200);
`}
>
{output_documents.map((document, idx) => {
return (
<li key={idx}>
<DocumentItem document={document} />
</li>
);
})}
</ul>
}
</Card>
</Flex>
);
}

function EmbeddingSpanInfo(props: {
span: Span;
spanAttributes: AttributeObject;
4 changes: 2 additions & 2 deletions app/src/pages/trace/__generated__/TracePageQuery.graphql.ts

Some generated files are not rendered by default.

38 changes: 37 additions & 1 deletion integration-tests/trace/llama_index/test_callback.py
@@ -2,11 +2,16 @@

import pytest
from gcsfs import GCSFileSystem
from llama_index import ServiceContext, StorageContext, load_index_from_storage
from llama_index import (
ServiceContext,
StorageContext,
load_index_from_storage,
)
from llama_index.agent import OpenAIAgent
from llama_index.callbacks import CallbackManager
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.graph_stores.simple import SimpleGraphStore
from llama_index.indices.postprocessor.cohere_rerank import CohereRerank
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.query_engine import RetrieverQueryEngine
@@ -33,6 +38,10 @@
MESSAGE_ROLE,
OUTPUT_MIME_TYPE,
OUTPUT_VALUE,
RERANKER_INPUT_DOCUMENTS,
RERANKER_MODEL_NAME,
RERANKER_OUTPUT_DOCUMENTS,
RERANKER_TOP_K,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
@@ -241,3 +250,30 @@ def add(a: int, b: int) -> int:
"title": "multiply",
"type": "object",
}


@pytest.mark.parametrize("model_name", ["text-davinci-003"], indirect=True)
def test_cohere_rerank(index: VectorStoreIndex) -> None:
callback_handler = OpenInferenceTraceCallbackHandler(exporter=NoOpExporter())
service_context = ServiceContext.from_defaults(
callback_manager=CallbackManager(handlers=[callback_handler])
)
cohere_rerank = CohereRerank(top_n=2)
query_engine = index.as_query_engine(
similarity_top_k=5,
node_postprocessors=[cohere_rerank],
service_context=service_context,
)
query_engine.query("How should timestamps be formatted?")

spans = {span.name: span for span in callback_handler.get_spans()}
assert "reranking" in spans
reranker_span = spans["reranking"]
assert reranker_span.span_kind == SpanKind.RERANKER
assert (
len(reranker_span.attributes[RERANKER_INPUT_DOCUMENTS])
== query_engine.retriever.similarity_top_k
)
assert len(reranker_span.attributes[RERANKER_OUTPUT_DOCUMENTS]) == cohere_rerank.top_n
assert reranker_span.attributes[RERANKER_TOP_K] == cohere_rerank.top_n
assert reranker_span.attributes[RERANKER_MODEL_NAME] == cohere_rerank.model
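Outside the test suite, the same pieces wire together in an application roughly as follows. This is a minimal sketch, not taken from this PR: the handler's import path is assumed from the file layout above, the default exporter is assumed to ship spans to a locally launched Phoenix app, and OPENAI_API_KEY / COHERE_API_KEY are assumed to be set.

import phoenix as px
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks import CallbackManager
from llama_index.indices.postprocessor.cohere_rerank import CohereRerank
from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler  # import path assumed

px.launch_app()  # start the local Phoenix UI; reranker spans should appear in the trace view

callback_handler = OpenInferenceTraceCallbackHandler()  # default exporter assumed
service_context = ServiceContext.from_defaults(
    callback_manager=CallbackManager(handlers=[callback_handler])
)
documents = SimpleDirectoryReader("./data").load_data()  # any small corpus
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

query_engine = index.as_query_engine(
    similarity_top_k=5,  # retriever hands 5 candidates to the reranker
    node_postprocessors=[CohereRerank(top_n=2)],  # reranker keeps the top 2
    service_context=service_context,
)
query_engine.query("How should timestamps be formatted?")
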
1 change: 1 addition & 0 deletions src/phoenix/server/api/types/Span.py
@@ -46,6 +46,7 @@ class SpanKind(Enum):
retriever = trace_schema.SpanKind.RETRIEVER
embedding = trace_schema.SpanKind.EMBEDDING
agent = trace_schema.SpanKind.AGENT
reranker = trace_schema.SpanKind.RERANKER
unknown = trace_schema.SpanKind.UNKNOWN

@classmethod
34 changes: 26 additions & 8 deletions src/phoenix/trace/llama_index/callback.py
@@ -51,6 +51,11 @@
MESSAGE_ROLE,
OUTPUT_MIME_TYPE,
OUTPUT_VALUE,
RERANKER_INPUT_DOCUMENTS,
RERANKER_MODEL_NAME,
RERANKER_OUTPUT_DOCUMENTS,
RERANKER_QUERY,
RERANKER_TOP_K,
RETRIEVAL_DOCUMENTS,
TOOL_DESCRIPTION,
TOOL_NAME,
@@ -82,6 +87,7 @@ class CBEventData(TypedDict, total=False):
def payload_to_semantic_attributes(
event_type: CBEventType,
payload: Dict[str, Any],
is_event_end: bool = False,
) -> Dict[str, Any]:
"""
Converts a LLMapp payload to a dictionary of semantic conventions compliant attributes.
@@ -95,10 +101,10 @@
{EMBEDDING_TEXT: text, EMBEDDING_VECTOR: vector}
for text, vector in zip(payload[EventPayload.CHUNKS], payload[EventPayload.EMBEDDINGS])
]
if EventPayload.QUERY_STR in payload:
if event_type is not CBEventType.RERANKING and EventPayload.QUERY_STR in payload:
attributes[INPUT_VALUE] = payload[EventPayload.QUERY_STR]
attributes[INPUT_MIME_TYPE] = MimeType.TEXT
if EventPayload.NODES in payload:
if event_type is not CBEventType.RERANKING and EventPayload.NODES in payload:
attributes[RETRIEVAL_DOCUMENTS] = [
{
DOCUMENT_ID: node_with_score.node.node_id,
@@ -128,11 +134,22 @@
if (usage := getattr(raw, "usage", None)) is not None:
attributes.update(_get_token_counts(usage))
if event_type is CBEventType.RERANKING:
... # TODO
# if EventPayload.TOP_K in payload:
# attributes[RERANKING_TOP_K] = payload[EventPayload.TOP_K]
# if EventPayload.MODEL_NAME in payload:
# attributes[RERANKING_MODEL_NAME] = payload[EventPayload.MODEL_NAME]
if EventPayload.TOP_K in payload:
attributes[RERANKER_TOP_K] = payload[EventPayload.TOP_K]
if EventPayload.MODEL_NAME in payload:
attributes[RERANKER_MODEL_NAME] = payload[EventPayload.MODEL_NAME]
if EventPayload.QUERY_STR in payload:
attributes[RERANKER_QUERY] = payload[EventPayload.QUERY_STR]
if nodes := payload.get(EventPayload.NODES):
attributes[RERANKER_OUTPUT_DOCUMENTS if is_event_end else RERANKER_INPUT_DOCUMENTS] = [
{
DOCUMENT_ID: node_with_score.node.node_id,
DOCUMENT_SCORE: node_with_score.score,
DOCUMENT_CONTENT: node_with_score.node.text,
DOCUMENT_METADATA: node_with_score.node.metadata,
}
for node_with_score in nodes
]
if EventPayload.TOOL in payload:
tool_metadata = cast(ToolMetadata, payload.get(EventPayload.TOOL))
attributes[TOOL_NAME] = tool_metadata.name
@@ -221,7 +238,7 @@ def on_event_end(
# Parse the payload to extract the parameters
if payload is not None:
event_data["attributes"].update(
payload_to_semantic_attributes(event_type, payload),
payload_to_semantic_attributes(event_type, payload, is_event_end=True),
)

def start_trace(self, trace_id: Optional[str] = None) -> None:
@@ -351,6 +368,7 @@ def _get_span_kind(event_type: CBEventType) -> SpanKind:
CBEventType.RETRIEVE: SpanKind.RETRIEVER,
CBEventType.FUNCTION_CALL: SpanKind.TOOL,
CBEventType.AGENT_STEP: SpanKind.AGENT,
CBEventType.RERANKING: SpanKind.RERANKER,
}.get(event_type, SpanKind.CHAIN)


Expand Down
Loading
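Putting the pieces together, a single reranking event handled by payload_to_semantic_attributes at event start (input documents) and event end (output documents) would produce span attributes shaped roughly like the following. This is an illustrative sketch: the literal key strings, model name, and document contents are assumptions, not captured output.

# Approximate shape of a "reranking" span's attributes after both
# on_event_start and on_event_end have fired. Keys and values are illustrative.
reranker_span_attributes = {
    "reranker.query": "How should timestamps be formatted?",
    "reranker.model_name": "rerank-english-v2.0",  # assumed Cohere model name
    "reranker.top_k": 2,
    "reranker.input_documents": [
        # candidates from the retriever; score (if present) is the retriever's
        # similarity score, produced before the reranker ever sees the document
        {"document.id": "node-7", "document.score": 0.81, "document.content": "...", "document.metadata": {}},
        # ... similarity_top_k documents in total
    ],
    "reranker.output_documents": [
        # top_n documents kept by the reranker; score is now a relevance score
        # for the query, assigned by the reranking model
        {"document.id": "node-3", "document.score": 0.97, "document.content": "...", "document.metadata": {}},
        {"document.id": "node-7", "document.score": 0.90, "document.content": "...", "document.metadata": {}},
    ],
}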