Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions docs/config/yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,9 @@ Where to put all vectors for the system. Configured for lancedb by default. This

The supported embeddings are:

- `text_unit.text`
- `document.text`
- `entity.title`
- `entity.description`
- `relationship.description`
- `community.title`
- `community.summary`
- `community.full_content`
- `text_unit_text`
- `entity_description`
- `community_full_content`

For example:

Expand All @@ -199,12 +194,12 @@ vector_store:
db_uri: output/lancedb
index_prefix: "christmas-carol"
embeddings_schema:
text_unit.text:
text_unit_text:
index_name: "text-unit-embeddings"
id_field: "id_custom"
vector_field: "vector_custom"
vector_size: 3072
entity.description:
entity_description:
id_field: "id_custom"

```
Expand All @@ -224,14 +219,9 @@ By default, the GraphRAG indexer will only export embeddings required for our qu

Supported embedding names are:

- `text_unit.text`
- `document.text`
- `entity.title`
- `entity.description`
- `relationship.description`
- `community.title`
- `community.summary`
- `community.full_content`
- `text_unit_text`
- `entity_description`
- `community_full_content`

#### Fields

Expand Down
9 changes: 5 additions & 4 deletions docs/examples_notebooks/api_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
]
},
{
Expand Down Expand Up @@ -170,7 +171,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag-monorepo",
"language": "python",
"name": "python3"
},
Expand All @@ -184,7 +185,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.9"
}
},
"nbformat": 4,
Expand Down
2 changes: 0 additions & 2 deletions docs/examples_notebooks/index_migration_to_v1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@
"tokenizer = get_tokenizer(model_config)\n",
"\n",
"await generate_text_embeddings(\n",
" documents=None,\n",
" relationships=None,\n",
" text_units=final_text_units,\n",
" entities=final_entities,\n",
" community_reports=final_community_reports,\n",
Expand Down
9 changes: 5 additions & 4 deletions docs/examples_notebooks/input_documents.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
]
},
{
Expand Down Expand Up @@ -171,7 +172,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"display_name": "graphrag-monorepo",
"language": "python",
"name": "python3"
},
Expand All @@ -185,7 +186,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
"version": "3.12.9"
}
},
"nbformat": 4,
Expand Down
20 changes: 5 additions & 15 deletions packages/graphrag/graphrag/config/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,12 @@

"""A module containing embeddings values."""

entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_text_embedding = "document.text"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"
entity_description_embedding = "entity_description"
community_full_content_embedding = "community_full_content"
text_unit_text_embedding = "text_unit_text"

all_embeddings: set[str] = {
entity_title_embedding,
entity_description_embedding,
relationship_description_embedding,
document_text_embedding,
community_title_embedding,
community_summary_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
Expand Down Expand Up @@ -47,5 +37,5 @@ def create_index_name(
raise KeyError(msg)

if index_prefix:
return f"{index_prefix}-{embedding_name}".replace(".", "-")
return embedding_name.replace(".", "-")
return f"{index_prefix}-{embedding_name}"
return embedding_name
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,8 @@
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import (
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
create_index_name,
document_text_embedding,
entity_description_embedding,
entity_title_embedding,
relationship_description_embedding,
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
Expand Down Expand Up @@ -47,29 +42,14 @@ async def run_workflow(
logger.info("Workflow started: generate_text_embeddings")
embedded_fields = config.embed_text.names
logger.info("Embedding the following fields: %s", embedded_fields)
documents = None
relationships = None
text_units = None
entities = None
community_reports = None
if document_text_embedding in embedded_fields:
documents = await load_table_from_storage("documents", context.output_storage)
if relationship_description_embedding in embedded_fields:
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
if text_unit_text_embedding in embedded_fields:
text_units = await load_table_from_storage("text_units", context.output_storage)
if (
entity_title_embedding in embedded_fields
or entity_description_embedding in embedded_fields
):
if entity_description_embedding in embedded_fields:
entities = await load_table_from_storage("entities", context.output_storage)
if (
community_title_embedding in embedded_fields
or community_summary_embedding in embedded_fields
or community_full_content_embedding in embedded_fields
):
if community_full_content_embedding in embedded_fields:
community_reports = await load_table_from_storage(
"community_reports", context.output_storage
)
Expand All @@ -87,8 +67,6 @@ async def run_workflow(
tokenizer = get_tokenizer(model_config)

output = await generate_text_embeddings(
documents=documents,
relationships=relationships,
text_units=text_units,
entities=entities,
community_reports=community_reports,
Expand All @@ -115,8 +93,6 @@ async def run_workflow(


async def generate_text_embeddings(
documents: pd.DataFrame | None,
relationships: pd.DataFrame | None,
text_units: pd.DataFrame | None,
entities: pd.DataFrame | None,
community_reports: pd.DataFrame | None,
Expand All @@ -131,26 +107,12 @@ async def generate_text_embeddings(
) -> dict[str, pd.DataFrame]:
"""All the steps to generate all embeddings."""
embedding_param_map = {
document_text_embedding: {
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
"embed_column": "text",
},
relationship_description_embedding: {
"data": relationships.loc[:, ["id", "description"]]
if relationships is not None
else None,
"embed_column": "description",
},
text_unit_text_embedding: {
"data": text_units.loc[:, ["id", "text"]]
if text_units is not None
else None,
"embed_column": "text",
},
entity_title_embedding: {
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
"embed_column": "title",
},
entity_description_embedding: {
"data": entities.loc[:, ["id", "title", "description"]].assign(
title_description=lambda df: df["title"] + ":" + df["description"]
Expand All @@ -159,18 +121,6 @@ async def generate_text_embeddings(
else None,
"embed_column": "title_description",
},
community_title_embedding: {
"data": community_reports.loc[:, ["id", "title"]]
if community_reports is not None
else None,
"embed_column": "title",
},
community_summary_embedding: {
"data": community_reports.loc[:, ["id", "summary"]]
if community_reports is not None
else None,
"embed_column": "summary",
},
community_full_content_embedding: {
"data": community_reports.loc[:, ["id", "full_content"]]
if community_reports is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ async def run_workflow(
output_storage, _, _ = get_update_storages(
config, context.state["update_timestamp"]
)

final_documents_df = context.state["incremental_update_final_documents"]
merged_relationships_df = context.state["incremental_update_merged_relationships"]
merged_text_units = context.state["incremental_update_merged_text_units"]
merged_entities_df = context.state["incremental_update_merged_entities"]
merged_community_reports = context.state[
Expand All @@ -50,8 +47,6 @@ async def run_workflow(
tokenizer = get_tokenizer(model_config)

result = await generate_text_embeddings(
documents=final_documents_df,
relationships=merged_relationships_df,
text_units=merged_text_units,
entities=merged_entities_df,
community_reports=merged_community_reports,
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@
],
"max_runtime": 150,
"expected_artifacts": [
"embeddings.text_unit.text.parquet",
"embeddings.entity.description.parquet",
"embeddings.community.full_content.parquet"
"embeddings.text_unit_text.parquet",
"embeddings.entity_description.parquet",
"embeddings.community_full_content.parquet"
]
}
},
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@
],
"max_runtime": 150,
"expected_artifacts": [
"embeddings.text_unit.text.parquet",
"embeddings.entity.description.parquet",
"embeddings.community.full_content.parquet"
"embeddings.text_unit_text.parquet",
"embeddings.entity_description.parquet",
"embeddings.community_full_content.parquet"
]
}
},
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/utils/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def test_create_index_name():
collection = create_index_name("default", "entity.title")
assert collection == "default-entity-title"
collection = create_index_name("default", "entity_description")
assert collection == "default-entity_description"


def test_create_index_name_invalid_embedding_throws():
Expand All @@ -16,5 +16,5 @@ def test_create_index_name_invalid_embedding_throws():


def test_create_index_name_invalid_embedding_does_not_throw():
collection = create_index_name("default", "invalid.name", validate=False)
assert collection == "default-invalid-name"
collection = create_index_name("default", "invalid_name", validate=False)
assert collection == "default-invalid_name"
9 changes: 5 additions & 4 deletions tests/verbs/test_create_community_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
CommunityReportResponse,
FindingModel,
)
from graphrag.index.workflows.create_community_reports import (
run_workflow,
)
from graphrag.utils.storage import load_table_from_storage

from graphrag.index.operations.summarize_communities.community_reports_extractor import (
CommunityReportResponse,
FindingModel,
)

from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
Expand Down
11 changes: 1 addition & 10 deletions tests/verbs/test_generate_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,9 @@ async def test_generate_text_embeddings():

# entity description should always be here, let's assert its format
entity_description_embeddings = await load_table_from_storage(
"embeddings.entity.description", context.output_storage
"embeddings.entity_description", context.output_storage
)

assert len(entity_description_embeddings.columns) == 2
assert "id" in entity_description_embeddings.columns
assert "embedding" in entity_description_embeddings.columns

# every other embedding is optional but we've turned them all on, so check a random one
document_text_embeddings = await load_table_from_storage(
"embeddings.document.text", context.output_storage
)

assert len(document_text_embeddings.columns) == 2
assert "id" in document_text_embeddings.columns
assert "embedding" in document_text_embeddings.columns
3 changes: 2 additions & 1 deletion unified-search-app/app/app_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
from typing import TYPE_CHECKING

import graphrag.api as api
import streamlit as st
from knowledge_loader.data_sources.loader import (
create_datasource,
Expand All @@ -18,6 +17,8 @@
from state.session_variables import SessionVariables
from ui.search import display_search_result

import graphrag.api as api

if TYPE_CHECKING:
import pandas as pd

Expand Down
Loading