Skip to content

Commit

Permalink
feat(community): Add property extraction for Nodes and Relationships …
Browse files Browse the repository at this point in the history
…in LLMGraphTransformer (#7256)

Co-authored-by: Jacob Lee <jacoblee93@gmail.com>
  • Loading branch information
gvasilei and jacoblee93 authored Dec 3, 2024
1 parent ad1cf28 commit 3b36a33
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
test.skip("convertToGraphDocuments", async () => {
const model = new ChatOpenAI({
temperature: 0,
modelName: "gpt-4-turbo-preview",
modelName: "gpt-4o-mini",
});

const llmGraphTransformer = new LLMGraphTransformer({
Expand All @@ -22,14 +22,12 @@ test.skip("convertToGraphDocuments", async () => {
const result = await llmGraphTransformer.convertToGraphDocuments([
new Document({ pageContent: "Elon Musk is suing OpenAI" }),
]);

// console.log(result);
});

test("convertToGraphDocuments with allowed", async () => {
const model = new ChatOpenAI({
temperature: 0,
modelName: "gpt-4-turbo-preview",
modelName: "gpt-4o-mini",
});

const llmGraphTransformer = new LLMGraphTransformer({
Expand All @@ -42,8 +40,6 @@ test("convertToGraphDocuments with allowed", async () => {
new Document({ pageContent: "Elon Musk is suing OpenAI" }),
]);

// console.log(JSON.stringify(result));

expect(result).toEqual([
new GraphDocument({
nodes: [
Expand All @@ -68,7 +64,7 @@ test("convertToGraphDocuments with allowed", async () => {
test("convertToGraphDocuments with allowed lowercased", async () => {
const model = new ChatOpenAI({
temperature: 0,
modelName: "gpt-4-turbo-preview",
modelName: "gpt-4o-mini",
});

const llmGraphTransformer = new LLMGraphTransformer({
Expand All @@ -81,8 +77,6 @@ test("convertToGraphDocuments with allowed lowercased", async () => {
new Document({ pageContent: "Elon Musk is suing OpenAI" }),
]);

// console.log(JSON.stringify(result));

expect(result).toEqual([
new GraphDocument({
nodes: [
Expand All @@ -103,3 +97,82 @@ test("convertToGraphDocuments with allowed lowercased", async () => {
}),
]);
});

test("convertToGraphDocuments with node properties", async () => {
const model = new ChatOpenAI({
temperature: 0,
modelName: "gpt-4o-mini",
});

const llmGraphTransformer = new LLMGraphTransformer({
llm: model,
allowedNodes: ["Person"],
allowedRelationships: ["KNOWS"],
nodeProperties: ["age", "country"],
});

const result = await llmGraphTransformer.convertToGraphDocuments([
new Document({ pageContent: "John is 30 years old and lives in Spain" }),
]);

expect(result).toEqual([
new GraphDocument({
nodes: [
new Node({
id: "John",
type: "Person",
properties: {
age: "30",
country: "Spain",
},
}),
],
relationships: [],
source: new Document({
pageContent: "John is 30 years old and lives in Spain",
metadata: {},
}),
}),
]);
});

test("convertToGraphDocuments with relationship properties", async () => {
const model = new ChatOpenAI({
temperature: 0,
modelName: "gpt-4o-mini",
});

const llmGraphTransformer = new LLMGraphTransformer({
llm: model,
allowedNodes: ["Person"],
allowedRelationships: ["KNOWS"],
relationshipProperties: ["since"],
});

const result = await llmGraphTransformer.convertToGraphDocuments([
new Document({ pageContent: "John has known Mary since 2020" }),
]);

expect(result).toEqual([
new GraphDocument({
nodes: [
new Node({ id: "John", type: "Person" }),
new Node({ id: "Mary", type: "Person" }),
],
relationships: [
new Relationship({
source: new Node({ id: "John", type: "Person" }),
target: new Node({ id: "Mary", type: "Person" }),
type: "KNOWS",
properties: {
since: "2020",
},
}),
],
source: new Document({
pageContent: "John has known Mary since 2020",
metadata: {},
}),
}),
]);
});
149 changes: 117 additions & 32 deletions libs/langchain-community/src/experimental/graph_transformers/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ interface OptionalEnumFieldProps {
fieldKwargs?: object;
}

interface SchemaProperty {
key: string;
value: string;
}

function toTitleCase(str: string): string {
return str
.split(" ")
Expand Down Expand Up @@ -86,50 +91,112 @@ function createOptionalEnumType({
return schema;
}

function createSchema(allowedNodes: string[], allowedRelationships: string[]) {
function createNodeSchema(allowedNodes: string[], nodeProperties: string[]) {
const nodeSchema = z.object({
id: z.string(),
type: createOptionalEnumType({
enumValues: allowedNodes,
description: "The type or label of the node.",
}),
});

return nodeProperties.length > 0
? nodeSchema.extend({
properties: z
.array(
z.object({
key: createOptionalEnumType({
enumValues: nodeProperties,
description: "Property key.",
}),
value: z.string().describe("Extracted value."),
})
)
.describe(`List of node properties`),
})
: nodeSchema;
}

function createRelationshipSchema(
allowedNodes: string[],
allowedRelationships: string[],
relationshipProperties: string[]
) {
const relationshipSchema = z.object({
sourceNodeId: z.string(),
sourceNodeType: createOptionalEnumType({
enumValues: allowedNodes,
description: "The source node of the relationship.",
}),
relationshipType: createOptionalEnumType({
enumValues: allowedRelationships,
description: "The type of the relationship.",
isRel: true,
}),
targetNodeId: z.string(),
targetNodeType: createOptionalEnumType({
enumValues: allowedNodes,
description: "The target node of the relationship.",
}),
});

return relationshipProperties.length > 0
? relationshipSchema.extend({
properties: z
.array(
z.object({
key: createOptionalEnumType({
enumValues: relationshipProperties,
description: "Property key.",
}),
value: z.string().describe("Extracted value."),
})
)
.describe(`List of relationship properties`),
})
: relationshipSchema;
}

function createSchema(
allowedNodes: string[],
allowedRelationships: string[],
nodeProperties: string[],
relationshipProperties: string[]
) {
const nodeSchema = createNodeSchema(allowedNodes, nodeProperties);
const relationshipSchema = createRelationshipSchema(
allowedNodes,
allowedRelationships,
relationshipProperties
);

const dynamicGraphSchema = z.object({
nodes: z
.array(
z.object({
id: z.string(),
type: createOptionalEnumType({
enumValues: allowedNodes,
description: "The type or label of the node.",
}),
})
)
.describe("List of nodes"),
nodes: z.array(nodeSchema).describe("List of nodes"),
relationships: z
.array(
z.object({
sourceNodeId: z.string(),
sourceNodeType: createOptionalEnumType({
enumValues: allowedNodes,
description: "The source node of the relationship.",
}),
relationshipType: createOptionalEnumType({
enumValues: allowedRelationships,
description: "The type of the relationship.",
isRel: true,
}),
targetNodeId: z.string(),
targetNodeType: createOptionalEnumType({
enumValues: allowedNodes,
description: "The target node of the relationship.",
}),
})
)
.array(relationshipSchema)
.describe("List of relationships."),
});

return dynamicGraphSchema;
}

function convertPropertiesToRecord(
properties: SchemaProperty[]
): Record<string, string> {
return properties.reduce((accumulator: Record<string, string>, prop) => {
accumulator[prop.key] = prop.value;
return accumulator;
}, {});
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
function mapToBaseNode(node: any): Node {
return new Node({
id: node.id,
type: node.type ? toTitleCase(node.type) : "",
properties: node.properties
? convertPropertiesToRecord(node.properties)
: {},
});
}

Expand All @@ -149,6 +216,9 @@ function mapToBaseRelationship(relationship: any): Relationship {
: "",
}),
type: relationship.relationshipType.replace(" ", "_").toUpperCase(),
properties: relationship.properties
? convertPropertiesToRecord(relationship.properties)
: {},
});
}

Expand All @@ -158,6 +228,8 @@ export interface LLMGraphTransformerProps {
allowedRelationships?: string[];
prompt?: ChatPromptTemplate;
strictMode?: boolean;
nodeProperties?: string[];
relationshipProperties?: string[];
}

export class LLMGraphTransformer {
Expand All @@ -170,12 +242,18 @@ export class LLMGraphTransformer {

strictMode: boolean;

nodeProperties: string[];

relationshipProperties: string[];

constructor({
llm,
allowedNodes = [],
allowedRelationships = [],
prompt = DEFAULT_PROMPT,
strictMode = true,
nodeProperties = [],
relationshipProperties = [],
}: LLMGraphTransformerProps) {
if (typeof llm.withStructuredOutput !== "function") {
throw new Error(
Expand All @@ -186,9 +264,16 @@ export class LLMGraphTransformer {
this.allowedNodes = allowedNodes;
this.allowedRelationships = allowedRelationships;
this.strictMode = strictMode;
this.nodeProperties = nodeProperties;
this.relationshipProperties = relationshipProperties;

// Define chain
const schema = createSchema(allowedNodes, allowedRelationships);
const schema = createSchema(
allowedNodes,
allowedRelationships,
nodeProperties,
relationshipProperties
);
const structuredLLM = llm.withStructuredOutput(zodToJsonSchema(schema));
this.chain = prompt.pipe(structuredLLM);
}
Expand Down

0 comments on commit 3b36a33

Please sign in to comment.