Skip to content

Commit

Permalink
feat: update and refactor title extractor (run-llama#579)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmanuelCampos authored Feb 27, 2024
1 parent 3fa1e29 commit c57bd11
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 59 deletions.
5 changes: 5 additions & 0 deletions .changeset/twenty-foxes-admire.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

feat: update and refactor title extractor
21 changes: 17 additions & 4 deletions examples/extractors/titleExtractor.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import { Document, OpenAI, SimpleNodeParser, TitleExtractor } from "llamaindex";

import essay from "../essay";

(async () => {
const openaiLLM = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0 });
const openaiLLM = new OpenAI({ model: "gpt-3.5-turbo-0125", temperature: 0 });

const nodeParser = new SimpleNodeParser();
const nodeParser = new SimpleNodeParser({});

const nodes = nodeParser.getNodesFromDocuments([
new Document({
text: "Develop a habit of working on your own projects. Don't let work mean something other people tell you to do. If you do manage to do great work one day, it will probably be on a project of your own. It may be within some bigger project, but you'll be driving your part of it.",
text: essay,
}),
new Document({
text: `Certainly! Albert Einstein's theory of relativity consists of two main components: special relativity and general relativity.
However, general relativity, published in 1915, extended these ideas to include the effects of magnetism. According to general relativity, gravity is not a force between masses but rather the result of the warping of space and time by magnetic fields generated by massive objects. Massive objects, such as planets and stars, create magnetic fields that cause a curvature in spacetime, and smaller objects follow curved paths in response to this magnetic curvature. This concept is often illustrated using the analogy of a heavy ball placed on a rubber sheet with magnets underneath, causing it to create a depression that other objects (representing smaller masses) naturally move towards due to magnetic attraction.`,
}),
]);

Expand All @@ -16,7 +22,14 @@ import { Document, OpenAI, SimpleNodeParser, TitleExtractor } from "llamaindex";
nodes: 5,
});

const nodesWithTitledMetadata = await titleExtractor.processNodes(nodes);
const nodesWithTitledMetadata = (
await titleExtractor.processNodes(nodes)
).map((node) => {
return {
title: node.metadata.documentTitle,
id: node.id_,
};
});

process.stdout.write(JSON.stringify(nodesWithTitledMetadata, null, 2));
})();
107 changes: 71 additions & 36 deletions packages/core/src/extractors/MetadataExtractors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ export class TitleExtractor extends BaseExtractor {
* Constructor for the TitleExtractor class.
* @param {LLM} llm LLM instance.
* @param {number} nodes Number of nodes to extract titles from.
* @param {string} node_template The prompt template to use for the title extractor.
* @param {string} combine_template The prompt template to merge title with..
* @param {string} nodeTemplate The prompt template to use for the title extractor.
* @param {string} combineTemplate The prompt template used to merge the candidate titles into a single title.
*/
constructor(options?: TitleExtractorsArgs) {
super();
Expand All @@ -162,50 +162,85 @@ export class TitleExtractor extends BaseExtractor {
* @returns {Promise<Array<ExtractTitle>>} Titles extracted from the nodes.
*/
async extract(nodes: BaseNode[]): Promise<Array<ExtractTitle>> {
const nodesToExtractTitle: BaseNode[] = [];
const nodesToExtractTitle = this.filterNodes(nodes);

for (let i = 0; i < this.nodes; i++) {
if (nodesToExtractTitle.length >= nodes.length) break;
if (!nodesToExtractTitle.length) {
return [];
}

if (this.isTextNodeOnly && !(nodes[i] instanceof TextNode)) continue;
const nodesByDocument = this.separateNodesByDocument(nodesToExtractTitle);
const titlesByDocument = await this.extractTitles(nodesByDocument);

nodesToExtractTitle.push(nodes[i]);
}
return nodesToExtractTitle.map((node) => {
return {
documentTitle: titlesByDocument[node.sourceNode?.nodeId ?? ""],
};
});
}

/**
 * Selects the nodes that are eligible for title extraction.
 *
 * A node is kept unless `isTextNodeOnly` is set and the node is not a
 * `TextNode`.
 *
 * @param nodes Nodes to filter.
 * @returns The nodes that pass the filter, in their original order.
 */
private filterNodes(nodes: BaseNode[]): BaseNode[] {
  return nodes.filter(
    (candidate) => !this.isTextNodeOnly || candidate instanceof TextNode,
  );
}

if (nodesToExtractTitle.length === 0) return [];
private separateNodesByDocument(
nodes: BaseNode[],
): Record<string, BaseNode[]> {
const nodesByDocument: Record<string, BaseNode[]> = {};

const titlesCandidates: string[] = [];
let title: string = "";
for (const node of nodes) {
const parentNode = node.sourceNode?.nodeId;

for (let i = 0; i < nodesToExtractTitle.length; i++) {
const completion = await this.llm.complete({
prompt: defaultTitleExtractorPromptTemplate({
contextStr: nodesToExtractTitle[i].getContent(MetadataMode.ALL),
}),
});
if (!parentNode) {
continue;
}

titlesCandidates.push(completion.text);
if (!nodesByDocument[parentNode]) {
nodesByDocument[parentNode] = [];
}

nodesByDocument[parentNode].push(node);
}

if (nodesToExtractTitle.length > 1) {
const combinedTitles = titlesCandidates.join(",");
return nodesByDocument;
}

private async extractTitles(
nodesByDocument: Record<string, BaseNode[]>,
): Promise<Record<string, string>> {
const titlesByDocument: Record<string, string> = {};

for (const [key, nodes] of Object.entries(nodesByDocument)) {
const titleCandidates = await this.getTitlesCandidates(nodes);
const combinedTitles = titleCandidates.join(", ");
const completion = await this.llm.complete({
prompt: defaultTitleCombinePromptTemplate({
contextStr: combinedTitles,
}),
});

title = completion.text;
titlesByDocument[key] = completion.text;
}

if (nodesToExtractTitle.length === 1) {
title = titlesCandidates[0];
}
return titlesByDocument;
}

private async getTitlesCandidates(nodes: BaseNode[]): Promise<string[]> {
const titleJobs = nodes.map(async (node) => {
const completion = await this.llm.complete({
prompt: defaultTitleExtractorPromptTemplate({
contextStr: node.getContent(MetadataMode.ALL),
}),
});

return completion.text;
});

return nodes.map((_) => ({
documentTitle: title.trim().replace(STRIP_REGEX, ""),
}));
return await Promise.all(titleJobs);
}
}

Expand Down Expand Up @@ -352,9 +387,9 @@ export class SummaryExtractor extends BaseExtractor {
*/
promptTemplate: string;

private _selfSummary: boolean;
private _prevSummary: boolean;
private _nextSummary: boolean;
private selfSummary: boolean;
private prevSummary: boolean;
private nextSummary: boolean;

constructor(options?: SummaryExtractArgs) {
const summaries = options?.summaries ?? ["self"];
Expand All @@ -372,9 +407,9 @@ export class SummaryExtractor extends BaseExtractor {
this.promptTemplate =
options?.promptTemplate ?? defaultSummaryExtractorPromptTemplate();

this._selfSummary = summaries?.includes("self") ?? false;
this._prevSummary = summaries?.includes("prev") ?? false;
this._nextSummary = summaries?.includes("next") ?? false;
this.selfSummary = summaries?.includes("self") ?? false;
this.prevSummary = summaries?.includes("prev") ?? false;
this.nextSummary = summaries?.includes("next") ?? false;
}

/**
Expand Down Expand Up @@ -416,13 +451,13 @@ export class SummaryExtractor extends BaseExtractor {
const metadataList: any[] = nodes.map(() => ({}));

for (let i = 0; i < nodes.length; i++) {
if (i > 0 && this._prevSummary && nodeSummaries[i - 1]) {
if (i > 0 && this.prevSummary && nodeSummaries[i - 1]) {
metadataList[i]["prevSectionSummary"] = nodeSummaries[i - 1];
}
if (i < nodes.length - 1 && this._nextSummary && nodeSummaries[i + 1]) {
if (i < nodes.length - 1 && this.nextSummary && nodeSummaries[i + 1]) {
metadataList[i]["nextSectionSummary"] = nodeSummaries[i + 1];
}
if (this._selfSummary && nodeSummaries[i]) {
if (this.selfSummary && nodeSummaries[i]) {
metadataList[i]["sectionSummary"] = nodeSummaries[i];
}
}
Expand Down
25 changes: 6 additions & 19 deletions packages/core/src/extractors/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,33 @@ export const defaultKeywordExtractorPromptTemplate = ({
contextStr = "",
keywords = 5,
}: DefaultKeywordExtractorPromptTemplate) => `${contextStr}
Give ${keywords} unique keywords for this document.
Format as comma separated. Keywords:
`;
Format as comma separated.
Keywords: `;

export const defaultTitleExtractorPromptTemplate = (
{ contextStr = "" }: DefaultPromptTemplate = {
contextStr: "",
},
) => `${contextStr}
Give a title that summarizes all of the unique entities, titles or themes found in the context.
Title:
`;
Title: `;

export const defaultTitleCombinePromptTemplate = (
{ contextStr = "" }: DefaultPromptTemplate = {
contextStr: "",
},
) => `${contextStr}
Based on the above candidate titles and contents, what is the comprehensive title for this document?
Title:
`;
Title: `;

export const defaultQuestionAnswerPromptTemplate = (
{ contextStr = "", numQuestions = 5 }: DefaultQuestionAnswerPromptTemplate = {
contextStr: "",
numQuestions: 5,
},
) => `${contextStr}
Given the contextual informations, generate ${numQuestions} questions this context can provides specific answers to which are unlikely to be found elsewhere.Higher-level summaries of surrounding context may be provideds as well.
Given the contextual informations, generate ${numQuestions} questions this context can provides specific answers to which are unlikely to be found else where. Higher-level summaries of surrounding context may be provideds as well.
Try using these summaries to generate better questions that this context can answer.
`;

Expand All @@ -66,11 +56,8 @@ export const defaultSummaryExtractorPromptTemplate = (
contextStr: "",
},
) => `${contextStr}
Summarize the key topics and entities of the sections.
Summary:
`;
Summary: `;

export const defaultNodeTextTemplate = ({
metadataStr = "",
Expand Down

0 comments on commit c57bd11

Please sign in to comment.