Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recursive report generation #1669

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250130182248267480.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix report generation recursion."
}
31 changes: 9 additions & 22 deletions graphrag/index/flows/create_final_community_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from graphrag.config.enums import AsyncType
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
restore_community_hierarchy,
summarize_communities,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
prep_community_report_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (
Expand All @@ -39,9 +38,6 @@
NODE_ID,
NODE_NAME,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
get_levels,
)


async def create_final_community_reports(
Expand All @@ -66,35 +62,26 @@ async def create_final_community_reports(
if claims_input is not None:
claims = _prep_claims(claims_input)

max_input_length = summarization_strategy.get(
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
)

local_contexts = prepare_community_reports(
nodes,
edges,
claims,
callbacks,
summarization_strategy.get("max_input_length", 16_000),
max_input_length,
)

community_hierarchy = restore_community_hierarchy(nodes)
levels = get_levels(nodes)

level_contexts = []
for level in levels:
level_context = prep_community_report_context(
local_context_df=local_contexts,
community_hierarchy_df=community_hierarchy,
level=level,
max_tokens=summarization_strategy.get(
"max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
),
)
level_contexts.append(level_context)

community_reports = await summarize_communities(
nodes,
local_contexts,
level_contexts,
prep_community_report_context,
callbacks,
cache,
summarization_strategy,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,
)
Expand Down
31 changes: 8 additions & 23 deletions graphrag/index/flows/create_final_community_reports_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
from graphrag.config import defaults
from graphrag.config.enums import AsyncType
from graphrag.index.operations.summarize_communities import (
restore_community_hierarchy,
summarize_communities,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
get_levels,
)
from graphrag.index.operations.summarize_communities_text.context_builder import (
prep_community_report_context,
prep_local_context,
Expand Down Expand Up @@ -46,36 +42,25 @@ async def create_final_community_reports_text(
nodes_df = nodes_input.merge(entities_df, on="id")
nodes = nodes_df.loc[nodes_df.loc[:, "community"] != -1]

max_input_length = summarization_strategy.get("max_input_length", 16_000)

# TEMP: forcing override of the prompt until we can put it into config
summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
# build initial local context for all communities

max_input_length = summarization_strategy.get(
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
)

local_contexts = prep_local_context(
communities, text_units, nodes, max_input_length
)

community_hierarchy = restore_community_hierarchy(nodes)
levels = get_levels(nodes)

level_contexts = []
for level in levels:
level_context = prep_community_report_context(
local_context_df=local_contexts,
community_hierarchy_df=community_hierarchy,
level=level,
max_tokens=summarization_strategy.get(
"max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
),
)
level_contexts.append(level_context)

community_reports = await summarize_communities(
nodes,
local_contexts,
level_contexts,
prep_community_report_context,
callbacks,
cache,
summarization_strategy,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


def prep_community_report_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int,
Expand All @@ -42,8 +43,6 @@ def prep_community_report_context(
- Check if local context fits within the limit, if yes use local context
- If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
"""
report_df = pd.DataFrame()

# Filter by community level
level_context_df = local_context_df.loc[
local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
Expand All @@ -62,7 +61,7 @@ def prep_community_report_context(
if invalid_context_df.empty:
return valid_context_df

if report_df.empty:
if report_df is None or report_df.empty:
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_tokens
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""A module containing create_community_reports and load_strategy methods definition."""

import logging
from collections.abc import Callable

import pandas as pd

Expand All @@ -12,6 +13,12 @@
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
get_levels,
)
from graphrag.index.operations.summarize_communities.restore_community_hierarchy import (
restore_community_hierarchy,
)
from graphrag.index.operations.summarize_communities.typing import (
CommunityReport,
CommunityReportsStrategy,
Expand All @@ -24,11 +31,13 @@


async def summarize_communities(
nodes: pd.DataFrame,
local_contexts,
level_contexts,
level_context_builder: Callable,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict,
max_input_length: int,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
):
Expand All @@ -37,6 +46,20 @@ async def summarize_communities(
tick = progress_ticker(callbacks.progress, len(local_contexts))
runner = load_strategy(strategy["type"])

community_hierarchy = restore_community_hierarchy(nodes)
levels = get_levels(nodes)

level_contexts = []
for level in levels:
level_context = level_context_builder(
pd.DataFrame(reports),
community_hierarchy_df=community_hierarchy,
local_context_df=local_contexts,
level=level,
max_tokens=max_input_length,
)
level_contexts.append(level_context)

for level_context in level_contexts:

async def run_generate(record):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def prep_local_context(


def prep_community_report_context(
local_context_df: pd.DataFrame,
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int,
report_df: pd.DataFrame | None = None,
max_tokens: int = 16000,
) -> pd.DataFrame:
"""
Expand Down
Loading