Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260219144634370703.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "create_communities streaming"
}
2 changes: 1 addition & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,4 @@ graphrag-common (no internal deps)
├── graphrag-cache (common, storage)
├── graphrag-llm (cache, common)
└── graphrag (all of the above)
```
```
26 changes: 17 additions & 9 deletions packages/graphrag/graphrag/index/operations/cluster_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""A module containing cluster_graph method definition."""

import logging
from collections import defaultdict

import pandas as pd

Expand Down Expand Up @@ -34,12 +35,9 @@ def cluster_graph(

clusters: dict[int, dict[int, list[str]]] = {}
for level in levels:
result = {}
result: dict[int, list[str]] = defaultdict(list)
clusters[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
community_id = raw_community_id
if community_id not in result:
result[community_id] = []
for node_id, community_id in node_id_to_community_map[level].items():
result[community_id].append(node_id)

results: Communities = []
Expand All @@ -64,15 +62,25 @@ def _compute_leiden_communities(
# so we replicate that by normalizing direction then keeping last.
lo = edge_df[["source", "target"]].min(axis=1)
hi = edge_df[["source", "target"]].max(axis=1)
edge_df = edge_df.assign(source=lo, target=hi)
edge_df = edge_df.drop_duplicates(subset=["source", "target"], keep="last")
edge_df["source"] = lo
edge_df["target"] = hi
edge_df.drop_duplicates(subset=["source", "target"], keep="last", inplace=True)

if use_lcc:
edge_df = stable_lcc(edge_df)

weights = (
edge_df["weight"].astype(float)
if "weight" in edge_df.columns
else pd.Series(1.0, index=edge_df.index)
)
edge_list: list[tuple[str, str, float]] = sorted(
(str(row["source"]), str(row["target"]), float(row.get("weight", 1.0)))
for _, row in edge_df.iterrows()
zip(
edge_df["source"].astype(str),
edge_df["target"].astype(str),
weights,
strict=True,
)
)

community_mapping = hierarchical_leiden(
Expand Down
161 changes: 108 additions & 53 deletions packages/graphrag/graphrag/index/workflows/create_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import logging
from datetime import datetime, timezone
from typing import cast
from typing import Any, cast
from uuid import uuid4

import numpy as np
import pandas as pd
from graphrag_storage.tables.table import Table

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
Expand All @@ -28,88 +29,122 @@ async def run_workflow(
"""All the steps to transform final communities."""
logger.info("Workflow started: create_communities")
reader = DataReader(context.output_table_provider)
entities = await reader.entities()
relationships = await reader.relationships()

max_cluster_size = config.cluster_graph.max_cluster_size
use_lcc = config.cluster_graph.use_lcc
seed = config.cluster_graph.seed

output = create_communities(
entities,
relationships,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=seed,
)

await context.output_table_provider.write_dataframe("communities", output)
async with (
context.output_table_provider.open("entities") as entities_table,
context.output_table_provider.open("communities") as communities_table,
):
sample_rows = await create_communities(
communities_table,
entities_table,
relationships,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=seed,
)

logger.info("Workflow completed: create_communities")
return WorkflowFunctionOutput(result=output)
return WorkflowFunctionOutput(result=sample_rows)


def create_communities(
entities: pd.DataFrame,
async def create_communities(
communities_table: Table,
entities_table: Table,
relationships: pd.DataFrame,
max_cluster_size: int,
use_lcc: bool,
seed: int | None = None,
) -> pd.DataFrame:
"""All the steps to transform final communities."""
) -> list[dict[str, Any]]:
"""Build communities from clustered relationships and stream rows to the table.

Args
----
communities_table: Table
Output table to write community rows to.
entities_table: Table
Table containing entity rows.
relationships: pd.DataFrame
Relationships DataFrame with source, target, weight,
text_unit_ids columns.
max_cluster_size: int
Maximum cluster size for hierarchical Leiden.
use_lcc: bool
Whether to restrict to the largest connected component.
seed: int | None
Random seed for deterministic clustering.

Returns
-------
list[dict[str, Any]]
Sample of up to 5 community rows for logging.
"""
clusters = cluster_graph(
relationships,
max_cluster_size,
use_lcc,
seed=seed,
)

title_to_entity_id: dict[str, str] = {}
async for row in entities_table:
title_to_entity_id[row["title"]] = row["id"]

communities = pd.DataFrame(
clusters, columns=pd.Index(["level", "community", "parent", "title"])
).explode("title")
communities["community"] = communities["community"].astype(int)

# aggregate entity ids for each community
entity_ids = communities.merge(entities, on="title", how="inner")
entity_map = communities[["community", "title"]].copy()
entity_map["entity_id"] = entity_map["title"].map(title_to_entity_id)
entity_ids = (
entity_ids.groupby("community").agg(entity_ids=("id", list)).reset_index()
entity_map
.dropna(subset=["entity_id"])
.groupby("community")
.agg(entity_ids=("entity_id", list))
.reset_index()
)

# aggregate relationships ids for each community
# these are limited to only those where the source and target are in the same community
max_level = communities["level"].max()
all_grouped = pd.DataFrame(
columns=["community", "level", "relationship_ids", "text_unit_ids"] # type: ignore
)
for level in range(max_level + 1):
communities_at_level = communities.loc[communities["level"] == level]
sources = relationships.merge(
communities_at_level, left_on="source", right_on="title", how="inner"
# aggregate relationship ids per community, limited to
# intra-community edges (source and target in the same community).
# Process one hierarchy level at a time to keep intermediate
# DataFrames small, then concat the grouped results once at the end.
level_results = []
for level in communities["level"].unique():
level_comms = communities[communities["level"] == level]
with_source = relationships.merge(
level_comms, left_on="source", right_on="title", how="inner"
)
targets = sources.merge(
communities_at_level, left_on="target", right_on="title", how="inner"
with_both = with_source.merge(
level_comms, left_on="target", right_on="title", how="inner"
)
matched = targets.loc[targets["community_x"] == targets["community_y"]]
text_units = matched.explode("text_unit_ids")
intra = with_both[with_both["community_x"] == with_both["community_y"]]
if intra.empty:
continue
grouped = (
text_units
.groupby(["community_x", "level_x", "parent_x"])
.agg(relationship_ids=("id", list), text_unit_ids=("text_unit_ids", list))
intra
.explode("text_unit_ids")
.groupby(["community_x", "parent_x"])
.agg(
relationship_ids=("id", list),
text_unit_ids=("text_unit_ids", list),
)
.reset_index()
)
grouped.rename(
columns={
"community_x": "community",
"level_x": "level",
"parent_x": "parent",
},
inplace=True,
)
all_grouped = pd.concat([
all_grouped,
grouped.loc[
:, ["community", "level", "parent", "relationship_ids", "text_unit_ids"]
],
])
grouped["level"] = level
level_results.append(grouped)

all_grouped = pd.concat(level_results, ignore_index=True).rename(
columns={
"community_x": "community",
"parent_x": "parent",
}
)

# deduplicate the lists
all_grouped["relationship_ids"] = all_grouped["relationship_ids"].apply(
Expand Down Expand Up @@ -146,7 +181,27 @@ def create_communities(
final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)

return final_communities.loc[
:,
COMMUNITIES_FINAL_COLUMNS,
]
output = final_communities.loc[:, COMMUNITIES_FINAL_COLUMNS]
rows = output.to_dict("records")
sample_rows: list[dict[str, Any]] = []
for row in rows:
row = _sanitize_row(row)
await communities_table.write(row)
if len(sample_rows) < 5:
sample_rows.append(row)
return sample_rows


def _to_native(value: Any) -> Any:
    """Return ``value`` as a built-in Python object if it is a numpy array or scalar."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # otherwise numpy booleans would pass through unconverted.
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    return value


def _sanitize_row(row: dict[str, Any]) -> dict[str, Any]:
    """Convert numpy types to native Python types for table serialization.

    Args
    ----
        row: dict[str, Any]
            A single record keyed by column name. It is not modified.

    Returns
    -------
        dict[str, Any]
            A new dict with the same keys in the same order, where
            ``np.ndarray`` values become lists and numpy boolean, integer
            and floating scalars become ``bool``, ``int`` and ``float``.
            Any other value is passed through unchanged.
    """
    return {key: _to_native(value) for key, value in row.items()}
Loading
Loading