diff --git a/.semversioner/next-release/patch-20260219144634370703.json b/.semversioner/next-release/patch-20260219144634370703.json new file mode 100644 index 000000000..01e3e7555 --- /dev/null +++ b/.semversioner/next-release/patch-20260219144634370703.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "create_communities streaming" +} diff --git a/RELEASE.md b/RELEASE.md index c9280015f..b078ca3f1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -184,4 +184,4 @@ graphrag-common (no internal deps) ├── graphrag-cache (common, storage) ├── graphrag-llm (cache, common) └── graphrag (all of the above) -``` +``` \ No newline at end of file diff --git a/packages/graphrag/graphrag/index/operations/cluster_graph.py b/packages/graphrag/graphrag/index/operations/cluster_graph.py index 745d6b514..d14db22aa 100644 --- a/packages/graphrag/graphrag/index/operations/cluster_graph.py +++ b/packages/graphrag/graphrag/index/operations/cluster_graph.py @@ -4,6 +4,7 @@ """A module containing cluster_graph method definition.""" import logging +from collections import defaultdict import pandas as pd @@ -34,12 +35,9 @@ def cluster_graph( clusters: dict[int, dict[int, list[str]]] = {} for level in levels: - result = {} + result: dict[int, list[str]] = defaultdict(list) clusters[level] = result - for node_id, raw_community_id in node_id_to_community_map[level].items(): - community_id = raw_community_id - if community_id not in result: - result[community_id] = [] + for node_id, community_id in node_id_to_community_map[level].items(): result[community_id].append(node_id) results: Communities = [] @@ -64,15 +62,25 @@ def _compute_leiden_communities( # so we replicate that by normalizing direction then keeping last. 
lo = edge_df[["source", "target"]].min(axis=1) hi = edge_df[["source", "target"]].max(axis=1) - edge_df = edge_df.assign(source=lo, target=hi) - edge_df = edge_df.drop_duplicates(subset=["source", "target"], keep="last") + edge_df["source"] = lo + edge_df["target"] = hi + edge_df.drop_duplicates(subset=["source", "target"], keep="last", inplace=True) if use_lcc: edge_df = stable_lcc(edge_df) + weights = ( + edge_df["weight"].astype(float) + if "weight" in edge_df.columns + else pd.Series(1.0, index=edge_df.index) + ) edge_list: list[tuple[str, str, float]] = sorted( - (str(row["source"]), str(row["target"]), float(row.get("weight", 1.0))) - for _, row in edge_df.iterrows() + zip( + edge_df["source"].astype(str), + edge_df["target"].astype(str), + weights, + strict=True, + ) ) community_mapping = hierarchical_leiden( diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index d80b78b1a..45476f941 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -5,11 +5,12 @@ import logging from datetime import datetime, timezone -from typing import cast +from typing import Any, cast from uuid import uuid4 import numpy as np import pandas as pd +from graphrag_storage.tables.table import Table from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.data_reader import DataReader @@ -28,34 +29,60 @@ async def run_workflow( """All the steps to transform final communities.""" logger.info("Workflow started: create_communities") reader = DataReader(context.output_table_provider) - entities = await reader.entities() relationships = await reader.relationships() + max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc seed = config.cluster_graph.seed - output = create_communities( - entities, - relationships, - 
max_cluster_size=max_cluster_size, - use_lcc=use_lcc, - seed=seed, - ) - - await context.output_table_provider.write_dataframe("communities", output) + async with ( + context.output_table_provider.open("entities") as entities_table, + context.output_table_provider.open("communities") as communities_table, + ): + sample_rows = await create_communities( + communities_table, + entities_table, + relationships, + max_cluster_size=max_cluster_size, + use_lcc=use_lcc, + seed=seed, + ) logger.info("Workflow completed: create_communities") - return WorkflowFunctionOutput(result=output) + return WorkflowFunctionOutput(result=sample_rows) -def create_communities( - entities: pd.DataFrame, +async def create_communities( + communities_table: Table, + entities_table: Table, relationships: pd.DataFrame, max_cluster_size: int, use_lcc: bool, seed: int | None = None, -) -> pd.DataFrame: - """All the steps to transform final communities.""" +) -> list[dict[str, Any]]: + """Build communities from clustered relationships and stream rows to the table. + + Args + ---- + communities_table: Table + Output table to write community rows to. + entities_table: Table + Table containing entity rows. + relationships: pd.DataFrame + Relationships DataFrame with source, target, weight, + text_unit_ids columns. + max_cluster_size: int + Maximum cluster size for hierarchical Leiden. + use_lcc: bool + Whether to restrict to the largest connected component. + seed: int | None + Random seed for deterministic clustering. + + Returns + ------- + list[dict[str, Any]] + Sample of up to 5 community rows for logging. 
+ """ clusters = cluster_graph( relationships, max_cluster_size, @@ -63,53 +90,61 @@ def create_communities( seed=seed, ) + title_to_entity_id: dict[str, str] = {} + async for row in entities_table: + title_to_entity_id[row["title"]] = row["id"] + communities = pd.DataFrame( clusters, columns=pd.Index(["level", "community", "parent", "title"]) ).explode("title") communities["community"] = communities["community"].astype(int) # aggregate entity ids for each community - entity_ids = communities.merge(entities, on="title", how="inner") + entity_map = communities[["community", "title"]].copy() + entity_map["entity_id"] = entity_map["title"].map(title_to_entity_id) entity_ids = ( - entity_ids.groupby("community").agg(entity_ids=("id", list)).reset_index() + entity_map + .dropna(subset=["entity_id"]) + .groupby("community") + .agg(entity_ids=("entity_id", list)) + .reset_index() ) - # aggregate relationships ids for each community - # these are limited to only those where the source and target are in the same community - max_level = communities["level"].max() - all_grouped = pd.DataFrame( - columns=["community", "level", "relationship_ids", "text_unit_ids"] # type: ignore - ) - for level in range(max_level + 1): - communities_at_level = communities.loc[communities["level"] == level] - sources = relationships.merge( - communities_at_level, left_on="source", right_on="title", how="inner" + # aggregate relationship ids per community, limited to + # intra-community edges (source and target in the same community). + # Process one hierarchy level at a time to keep intermediate + # DataFrames small, then concat the grouped results once at the end. 
+ level_results = [] + for level in communities["level"].unique(): + level_comms = communities[communities["level"] == level] + with_source = relationships.merge( + level_comms, left_on="source", right_on="title", how="inner" ) - targets = sources.merge( - communities_at_level, left_on="target", right_on="title", how="inner" + with_both = with_source.merge( + level_comms, left_on="target", right_on="title", how="inner" ) - matched = targets.loc[targets["community_x"] == targets["community_y"]] - text_units = matched.explode("text_unit_ids") + intra = with_both[with_both["community_x"] == with_both["community_y"]] + if intra.empty: + continue grouped = ( - text_units - .groupby(["community_x", "level_x", "parent_x"]) - .agg(relationship_ids=("id", list), text_unit_ids=("text_unit_ids", list)) + intra + .explode("text_unit_ids") + .groupby(["community_x", "parent_x"]) + .agg( + relationship_ids=("id", list), + text_unit_ids=("text_unit_ids", list), + ) .reset_index() ) - grouped.rename( - columns={ - "community_x": "community", - "level_x": "level", - "parent_x": "parent", - }, - inplace=True, - ) - all_grouped = pd.concat([ - all_grouped, - grouped.loc[ - :, ["community", "level", "parent", "relationship_ids", "text_unit_ids"] - ], - ]) + grouped["level"] = level + level_results.append(grouped) + + all_grouped = pd.concat(level_results, ignore_index=True).rename( + columns={ + "community_x": "community", + "parent_x": "parent", + } + ) # deduplicate the lists all_grouped["relationship_ids"] = all_grouped["relationship_ids"].apply( @@ -146,7 +181,27 @@ def create_communities( final_communities["period"] = datetime.now(timezone.utc).date().isoformat() final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len) - return final_communities.loc[ - :, - COMMUNITIES_FINAL_COLUMNS, - ] + output = final_communities.loc[:, COMMUNITIES_FINAL_COLUMNS] + rows = output.to_dict("records") + sample_rows: list[dict[str, Any]] = [] + for row in rows: + row = 
_sanitize_row(row) + await communities_table.write(row) + if len(sample_rows) < 5: + sample_rows.append(row) + return sample_rows + + +def _sanitize_row(row: dict[str, Any]) -> dict[str, Any]: + """Convert numpy types to native Python types for table serialization.""" + sanitized = {} + for key, value in row.items(): + if isinstance(value, np.ndarray): + sanitized[key] = value.tolist() + elif isinstance(value, np.integer): + sanitized[key] = int(value) + elif isinstance(value, np.floating): + sanitized[key] = float(value) + else: + sanitized[key] = value + return sanitized diff --git a/tests/unit/indexing/test_cluster_graph.py b/tests/unit/indexing/test_cluster_graph.py new file mode 100644 index 000000000..cf717bc35 --- /dev/null +++ b/tests/unit/indexing/test_cluster_graph.py @@ -0,0 +1,295 @@ +# Copyright (C) 2026 Microsoft + +"""Tests for the cluster_graph operation. + +These tests pin down the behavior of cluster_graph and its internal +_compute_leiden_communities function so that refactoring (vectorizing +iterrows, reducing copies, etc.) can be verified against known output. 
+""" + +import pandas as pd +import pytest +from graphrag.index.operations.cluster_graph import ( + Communities, + cluster_graph, +) + + +def _make_edges( + rows: list[tuple[str, str, float]], +) -> pd.DataFrame: + """Build a minimal relationships DataFrame from (source, target, weight).""" + return pd.DataFrame([{"source": s, "target": t, "weight": w} for s, t, w in rows]) + + +def _node_sets(clusters: Communities) -> list[set[str]]: + """Extract sorted-by-level list of node sets from cluster output.""" + return [set(nodes) for _, _, _, nodes in clusters] + + +# ------------------------------------------------------------------- +# Basic clustering +# ------------------------------------------------------------------- + + +class TestClusterGraphBasic: + """Verify basic clustering on small synthetic graphs.""" + + def test_single_triangle(self): + """A single triangle should produce one community at level 0.""" + edges = _make_edges([("X", "Y", 1.0), ("X", "Z", 1.0), ("Y", "Z", 1.0)]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + assert len(clusters) == 1 + level, _cid, parent, nodes = clusters[0] + assert level == 0 + assert parent == -1 + assert set(nodes) == {"X", "Y", "Z"} + + def test_two_disconnected_cliques(self): + """Two disconnected triangles should produce two communities.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + assert len(clusters) == 2 + node_sets = _node_sets(clusters) + assert {"A", "B", "C"} in node_sets + assert {"D", "E", "F"} in node_sets + for level, _, parent, _ in clusters: + assert level == 0 + assert parent == -1 + + def test_lcc_filters_to_largest_component(self): + """With use_lcc=True, only the largest connected component is kept.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + 
("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=True, seed=42) + + assert len(clusters) == 1 + all_nodes = set(clusters[0][3]) + assert len(all_nodes) == 3 + + +# ------------------------------------------------------------------- +# Edge normalization +# ------------------------------------------------------------------- + + +class TestEdgeNormalization: + """Verify that direction normalization and deduplication work.""" + + def test_reversed_edges_produce_same_result(self): + """Reversing all edge directions should yield identical clusters.""" + forward = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + backward = _make_edges([ + ("B", "A", 1.0), + ("C", "A", 1.0), + ("C", "B", 1.0), + ("E", "D", 1.0), + ("F", "D", 1.0), + ("F", "E", 1.0), + ]) + clusters_fwd = cluster_graph( + forward, max_cluster_size=10, use_lcc=False, seed=42 + ) + clusters_bwd = cluster_graph( + backward, max_cluster_size=10, use_lcc=False, seed=42 + ) + + assert _node_sets(clusters_fwd) == _node_sets(clusters_bwd) + + def test_duplicate_edges_are_deduped(self): + """A→B and B→A should be treated as one edge after normalization.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("B", "A", 2.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + assert len(clusters) == 1 + assert set(clusters[0][3]) == {"A", "B", "C"} + + def test_missing_weight_defaults_to_one(self): + """Edges without a weight column should default to weight 1.0.""" + edges = pd.DataFrame({ + "source": ["A", "A", "B"], + "target": ["B", "C", "C"], + }) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + assert len(clusters) == 1 + assert set(clusters[0][3]) == {"A", "B", "C"} + + +# ------------------------------------------------------------------- +# Determinism 
+# ------------------------------------------------------------------- + + +class TestDeterminism: + """Verify that seeding produces reproducible results.""" + + def test_same_seed_same_result(self): + """Identical seed should yield identical output.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + c1 = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=123) + c2 = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=123) + + assert c1 == c2 + + def test_does_not_mutate_input(self): + """cluster_graph should not modify the input DataFrame.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ]) + original = edges.copy() + cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + pd.testing.assert_frame_equal(edges, original) + + +# ------------------------------------------------------------------- +# Output structure +# ------------------------------------------------------------------- + + +class TestOutputStructure: + """Verify the shape and types of the Communities output.""" + + def test_output_tuple_structure(self): + """Each entry should be (level, community_id, parent, node_list).""" + edges = _make_edges([("A", "B", 1.0), ("A", "C", 1.0), ("B", "C", 1.0)]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + for entry in clusters: + assert len(entry) == 4 + level, cid, parent, nodes = entry + assert isinstance(level, int) + assert isinstance(cid, int) + assert isinstance(parent, int) + assert isinstance(nodes, list) + assert all(isinstance(n, str) for n in nodes) + + def test_level_zero_has_parent_minus_one(self): + """All level-0 clusters should have parent == -1.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + clusters = cluster_graph(edges, max_cluster_size=10, 
use_lcc=False, seed=42) + + for level, _, parent, _ in clusters: + if level == 0: + assert parent == -1 + + def test_all_nodes_covered_at_each_level(self): + """At any given level, the union of all community nodes should + equal exactly the set of all nodes in the graph for that level.""" + edges = _make_edges([ + ("A", "B", 1.0), + ("A", "C", 1.0), + ("B", "C", 1.0), + ("D", "E", 1.0), + ("D", "F", 1.0), + ("E", "F", 1.0), + ]) + clusters = cluster_graph(edges, max_cluster_size=10, use_lcc=False, seed=42) + + levels: dict[int, set[str]] = {} + for level, _, _, nodes in clusters: + levels.setdefault(level, set()).update(nodes) + + all_nodes = {"A", "B", "C", "D", "E", "F"} + for level, covered_nodes in levels.items(): + assert covered_nodes == all_nodes, ( + f"Level {level}: expected {all_nodes}, got {covered_nodes}" + ) + + +# ------------------------------------------------------------------- +# Real test data (golden file regression) +# ------------------------------------------------------------------- + + +class TestClusterGraphRealData: + """Regression tests using the shared test fixture data.""" + + @pytest.fixture + def relationships(self) -> pd.DataFrame: + """Load the test relationships fixture.""" + return pd.read_parquet("tests/verbs/data/relationships.parquet") + + def test_cluster_count(self, relationships: pd.DataFrame): + """Pin the expected number of clusters from the fixture data.""" + clusters = cluster_graph( + relationships, + max_cluster_size=10, + use_lcc=True, + seed=0xDEADBEEF, + ) + assert len(clusters) == 122 + + def test_level_distribution(self, relationships: pd.DataFrame): + """Pin the expected number of clusters per level.""" + clusters = cluster_graph( + relationships, + max_cluster_size=10, + use_lcc=True, + seed=0xDEADBEEF, + ) + from collections import Counter + + level_counts = Counter(c[0] for c in clusters) + assert level_counts == {0: 23, 1: 65, 2: 32, 3: 2} + + def test_level_zero_nodes_sample(self, relationships: 
pd.DataFrame): + """Spot-check a few known nodes in level-0 clusters.""" + clusters = cluster_graph( + relationships, + max_cluster_size=10, + use_lcc=True, + seed=0xDEADBEEF, + ) + level_0 = [c for c in clusters if c[0] == 0] + all_level_0_nodes = set() + for _, _, _, nodes in level_0: + all_level_0_nodes.update(nodes) + + assert "SCROOGE" in all_level_0_nodes + assert "ABRAHAM" in all_level_0_nodes + assert "JACOB MARLEY" in all_level_0_nodes diff --git a/tests/unit/indexing/test_create_communities.py b/tests/unit/indexing/test_create_communities.py new file mode 100644 index 000000000..e3f225bf0 --- /dev/null +++ b/tests/unit/indexing/test_create_communities.py @@ -0,0 +1,620 @@ +# Copyright (C) 2026 Microsoft + +"""Tests for the create_communities pure function. + +These tests pin down the behavior of the create_communities function +independently of the workflow runner, so that refactoring (vectorizing +the per-level loop, streaming entity reads, streaming writes, etc.) +can be verified against known output. 
+""" + +import uuid +from typing import Any + +import numpy as np +import pandas as pd +import pytest +from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS +from graphrag.index.workflows.create_communities import ( + _sanitize_row, + create_communities, +) +from graphrag_storage.tables.csv_table import CSVTable +from graphrag_storage.tables.table import Table + + +class FakeTable(CSVTable): + """In-memory table that collects written rows for test assertions.""" + + def __init__(self) -> None: + self.rows: list[dict[str, Any]] = [] + + async def write(self, row: dict[str, Any]) -> None: + """Append a row to the in-memory store.""" + self.rows.append(row) + + +class FakeEntitiesTable(Table): + """In-memory read-only table that supports async iteration.""" + + def __init__(self, rows: list[dict[str, Any]]) -> None: + self._rows = rows + self._index = 0 + + def __aiter__(self): + """Return an async iterator over the rows.""" + self._index = 0 + return self + + async def __anext__(self) -> dict[str, Any]: + """Yield the next row or stop.""" + if self._index >= len(self._rows): + raise StopAsyncIteration + row = self._rows[self._index] + self._index += 1 + return row + + async def length(self) -> int: + """Return number of rows.""" + return len(self._rows) + + async def has(self, row_id: str) -> bool: + """Check if a row with the given ID exists.""" + return any(r.get("id") == row_id for r in self._rows) + + async def write(self, row: dict[str, Any]) -> None: + """Not supported for read-only table.""" + raise NotImplementedError + + async def close(self) -> None: + """No-op.""" + + +async def _run_create_communities( + title_to_entity_id: dict[str, str], + relationships: pd.DataFrame, + **kwargs: Any, +) -> pd.DataFrame: + """Helper that runs create_communities with fake tables and returns all rows as a DataFrame.""" + communities_table = FakeTable() + entity_rows = [ + {"id": eid, "title": title} for title, eid in title_to_entity_id.items() + ] + 
entities_table = FakeEntitiesTable(entity_rows) + await create_communities(communities_table, entities_table, relationships, **kwargs) + return pd.DataFrame(communities_table.rows) + + +def _make_title_to_entity_id( + rows: list[tuple[str, str]], +) -> dict[str, str]: + """Build a title-to-entity-id mapping from (id, title) pairs.""" + return {title: eid for eid, title in rows} + + +def _make_relationships( + rows: list[tuple[str, str, str, float, list[str]]], +) -> pd.DataFrame: + """Build a minimal relationships DataFrame. + + Each row is (id, source, target, weight, text_unit_ids). + """ + return pd.DataFrame([ + { + "id": rid, + "source": src, + "target": tgt, + "weight": w, + "text_unit_ids": tuids, + "human_readable_id": i, + } + for i, (rid, src, tgt, w, tuids) in enumerate(rows) + ]) + + +@pytest.fixture +def two_triangles(): + """Two disconnected triangles: {A,B,C} and {D,E,F}.""" + title_to_entity_id = _make_title_to_entity_id([ + ("e1", "A"), + ("e2", "B"), + ("e3", "C"), + ("e4", "D"), + ("e5", "E"), + ("e6", "F"), + ]) + relationships = _make_relationships([ + ("r1", "A", "B", 1.0, ["t1"]), + ("r2", "A", "C", 1.0, ["t1", "t2"]), + ("r3", "B", "C", 1.0, ["t2"]), + ("r4", "D", "E", 1.0, ["t3"]), + ("r5", "D", "F", 1.0, ["t3", "t4"]), + ("r6", "E", "F", 1.0, ["t4"]), + ]) + return title_to_entity_id, relationships + + +# ------------------------------------------------------------------- +# Column schema +# ------------------------------------------------------------------- + + +class TestOutputSchema: + """Verify the output DataFrame has the expected column schema.""" + + async def test_has_all_final_columns(self, two_triangles): + """Output must have exactly the COMMUNITIES_FINAL_COLUMNS.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + assert list(result.columns) == COMMUNITIES_FINAL_COLUMNS + + async def 
test_column_order_matches_schema(self, two_triangles): + """Column order must match the schema constant exactly.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for i, col_name in enumerate(COMMUNITIES_FINAL_COLUMNS): + assert result.columns[i] == col_name + + +# ------------------------------------------------------------------- +# Metadata fields +# ------------------------------------------------------------------- + + +class TestMetadataFields: + """Verify computed metadata fields like id, title, size, period.""" + + async def test_uuid_ids(self, two_triangles): + """Each community id should be a valid UUID4.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for _, row in result.iterrows(): + parsed = uuid.UUID(row["id"]) + assert parsed.version == 4 + + async def test_title_format(self, two_triangles): + """Title should be 'Community N' where N is the community id.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for _, row in result.iterrows(): + assert row["title"] == f"Community {row['community']}" + + async def test_human_readable_id_equals_community(self, two_triangles): + """human_readable_id should equal the community integer id.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + assert (result["human_readable_id"] == result["community"]).all() + + async def test_size_equals_entity_count(self, two_triangles): + """size should equal the length of entity_ids.""" + title_to_entity_id, 
relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for _, row in result.iterrows(): + assert row["size"] == len(row["entity_ids"]) + + async def test_period_is_iso_date(self, two_triangles): + """period should be a valid ISO date string.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + from datetime import date + + for _, row in result.iterrows(): + date.fromisoformat(row["period"]) + + +# ------------------------------------------------------------------- +# Entity aggregation +# ------------------------------------------------------------------- + + +class TestEntityAggregation: + """Verify that entity_ids are correctly aggregated per community.""" + + async def test_entity_ids_per_community(self, two_triangles): + """Each community should contain exactly the entities matching + its cluster nodes.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + comm_0 = result[result["community"] == 0].iloc[0] + comm_1 = result[result["community"] == 1].iloc[0] + + assert sorted(comm_0["entity_ids"]) == ["e1", "e2", "e3"] + assert sorted(comm_1["entity_ids"]) == ["e4", "e5", "e6"] + + async def test_entity_ids_are_lists(self, two_triangles): + """entity_ids should be Python lists, not numpy arrays.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for _, row in result.iterrows(): + assert isinstance(row["entity_ids"], list) + + +# ------------------------------------------------------------------- +# Relationship and text_unit aggregation +# 
------------------------------------------------------------------- + + +class TestRelationshipAggregation: + """Verify that relationship_ids and text_unit_ids are correctly + aggregated (intra-community only) and deduplicated.""" + + async def test_relationship_ids_per_community(self, two_triangles): + """Each community should only include relationships where both + endpoints are in the same community.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + comm_0 = result[result["community"] == 0].iloc[0] + comm_1 = result[result["community"] == 1].iloc[0] + + assert sorted(comm_0["relationship_ids"]) == ["r1", "r2", "r3"] + assert sorted(comm_1["relationship_ids"]) == ["r4", "r5", "r6"] + + async def test_text_unit_ids_per_community(self, two_triangles): + """text_unit_ids should be the deduplicated union of text units + from the community's intra-community relationships.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + comm_0 = result[result["community"] == 0].iloc[0] + comm_1 = result[result["community"] == 1].iloc[0] + + assert sorted(comm_0["text_unit_ids"]) == ["t1", "t2"] + assert sorted(comm_1["text_unit_ids"]) == ["t3", "t4"] + + async def test_lists_are_sorted_and_deduplicated(self, two_triangles): + """relationship_ids and text_unit_ids should be sorted with + no duplicates.""" + title_to_entity_id, relationships = two_triangles + result = await _run_create_communities( + title_to_entity_id, + relationships, + max_cluster_size=10, + use_lcc=False, + seed=42, + ) + for _, row in result.iterrows(): + assert row["relationship_ids"] == sorted(set(row["relationship_ids"])) + assert row["text_unit_ids"] == sorted(set(row["text_unit_ids"])) + + async def 
# NOTE(review): this chunk begins mid-definition — the "async def" prefix and
# the enclosing class header of this first test lie before the visible range.
test_cross_community_relationships_excluded(self):
    """A relationship spanning two communities must not appear in
    either community's relationship_ids."""
    # Two triangles (A-B-C and D-E-F) joined by a single weak cross-edge
    # "r_cross"; its text unit "t_cross" is unique, so any leakage of the
    # cross-edge into a community is detectable via either column.
    title_to_entity_id = _make_title_to_entity_id([
        ("e1", "A"),
        ("e2", "B"),
        ("e3", "C"),
        ("e4", "D"),
        ("e5", "E"),
        ("e6", "F"),
    ])
    relationships = _make_relationships([
        ("r1", "A", "B", 1.0, ["t1"]),
        ("r2", "A", "C", 1.0, ["t1"]),
        ("r3", "B", "C", 1.0, ["t1"]),
        ("r_cross", "C", "D", 0.1, ["t_cross"]),
        ("r4", "D", "E", 1.0, ["t2"]),
        ("r5", "D", "F", 1.0, ["t2"]),
        ("r6", "E", "F", 1.0, ["t2"]),
    ])
    result = await _run_create_communities(
        title_to_entity_id,
        relationships,
        max_cluster_size=10,
        use_lcc=False,
        seed=42,
    )
    all_rel_ids = []
    for _, row in result.iterrows():
        all_rel_ids.extend(row["relationship_ids"])
    # The cross-community edge must not be claimed by any community...
    assert "r_cross" not in all_rel_ids
    # ...and its text unit must not leak into any community's text_unit_ids.
    assert "t_cross" not in [
        tid for _, row in result.iterrows() for tid in row["text_unit_ids"]
    ]


# -------------------------------------------------------------------
# Parent / children tree
# -------------------------------------------------------------------


class TestParentChildTree:
    """Verify the parent-child tree structure is consistent."""

    async def test_level_zero_parent_is_minus_one(self, two_triangles) -> None:
        """All level-0 communities should have parent == -1."""
        title_to_entity_id, relationships = two_triangles
        result = await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=False,
            seed=42,
        )
        # Root-level communities carry the sentinel parent value -1.
        lvl0 = result[result["level"] == 0]
        assert (lvl0["parent"] == -1).all()

    async def test_leaf_communities_have_empty_children(self, two_triangles) -> None:
        """Communities that are nobody's parent should have children=[]."""
        title_to_entity_id, relationships = two_triangles
        result = await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=False,
            seed=42,
        )
        # Checked in the "empty children => no row claims this community as
        # parent" direction; the reverse direction is covered by the
        # real-data bidirectional test below.
        for _, row in result.iterrows():
            children = row["children"]
            if isinstance(children, list) and len(children) == 0:
                child_rows = result[result["parent"] == row["community"]]
                assert len(child_rows) == 0

    async def test_parent_child_bidirectional_consistency_real_data(self) -> None:
        """For real test data: if community X lists Y as child,
        then Y's parent must be X."""
        entities_df = pd.read_parquet("tests/verbs/data/entities.parquet")
        title_to_entity_id = dict(
            zip(entities_df["title"], entities_df["id"], strict=False)
        )
        relationships = pd.read_parquet("tests/verbs/data/relationships.parquet")
        result = await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=True,
            # Arbitrary fixed seed: keeps the clustering deterministic so the
            # parent/child assertions are stable across runs.
            seed=0xDEADBEEF,
        )
        for _, row in result.iterrows():
            children = row["children"]
            # "children" may be a list or a numpy array depending on the
            # producer, so probe via __len__ rather than isinstance(list).
            if hasattr(children, "__len__") and len(children) > 0:
                for child_id in children:
                    child_row = result[result["community"] == child_id]
                    assert len(child_row) == 1, (
                        f"Child {child_id} not found or duplicated"
                    )
                    assert child_row.iloc[0]["parent"] == row["community"]


# -------------------------------------------------------------------
# LCC filtering
# -------------------------------------------------------------------


class TestLccFiltering:
    """Verify LCC filtering interaction with create_communities."""

    async def test_lcc_reduces_community_count(self) -> None:
        """With use_lcc=True and two disconnected components, only the
        larger component's communities should appear."""
        # NOTE(review): the two components here are both 3-node triangles,
        # i.e. EQUAL size — which one the LCC filter keeps rests on its
        # tie-breaking; the fixed seed and assertions below assume that
        # choice is deterministic. Confirm against stable_lcc's contract.
        title_to_entity_id = _make_title_to_entity_id([
            ("e1", "A"),
            ("e2", "B"),
            ("e3", "C"),
            ("e4", "D"),
            ("e5", "E"),
            ("e6", "F"),
        ])
        relationships = _make_relationships([
            ("r1", "A", "B", 1.0, ["t1"]),
            ("r2", "A", "C", 1.0, ["t1"]),
            ("r3", "B", "C", 1.0, ["t1"]),
            ("r4", "D", "E", 1.0, ["t2"]),
            ("r5", "D", "F", 1.0, ["t2"]),
            ("r6", "E", "F", 1.0, ["t2"]),
        ])
        result_no_lcc = await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=False,
            seed=42,
        )
        result_lcc = await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=True,
            seed=42,
        )
        # Restricting to the LCC drops one whole component, so strictly
        # fewer communities remain — exactly one for a single triangle.
        assert len(result_lcc) < len(result_no_lcc)
        assert len(result_lcc) == 1


# -------------------------------------------------------------------
# Golden file regression (real test data)
# -------------------------------------------------------------------


class TestRealDataRegression:
    """Regression tests using the shared test fixture data.

    These pin exact values so any behavioral change during refactoring
    is caught immediately.
    """

    @pytest.fixture
    async def real_result(self) -> pd.DataFrame:
        """Run create_communities on the test fixture data."""
        entities_df = pd.read_parquet("tests/verbs/data/entities.parquet")
        title_to_entity_id = dict(
            zip(entities_df["title"], entities_df["id"], strict=False)
        )
        relationships = pd.read_parquet("tests/verbs/data/relationships.parquet")
        return await _run_create_communities(
            title_to_entity_id,
            relationships,
            max_cluster_size=10,
            use_lcc=True,
            seed=0xDEADBEEF,
        )

    async def test_row_count(self, real_result: pd.DataFrame) -> None:
        """Pin the expected number of communities."""
        assert len(real_result) == 122

    async def test_level_distribution(self, real_result: pd.DataFrame) -> None:
        """Pin the expected number of communities per level."""
        from collections import Counter

        counts = Counter(real_result["level"].tolist())
        assert counts == {0: 23, 1: 65, 2: 32, 3: 2}

    async def test_values_match_golden_file(self, real_result: pd.DataFrame) -> None:
        """The output should match the golden Parquet file for all
        columns except id (UUID) and period (date-dependent)."""
        expected = pd.read_parquet("tests/verbs/data/communities.parquet")

        assert len(real_result) == len(expected)

        # id is a fresh UUID per run and period is run-date-dependent, so
        # they can never match the golden file; children needs the
        # element-wise comparison done separately below.
        skip_columns = {"id", "period", "children"}
        for col in COMMUNITIES_FINAL_COLUMNS:
            if col in skip_columns:
                continue
            pd.testing.assert_series_equal(
                real_result[col],
                expected[col],
                check_dtype=False,
                check_index=False,
                check_names=False,
                obj=f"Column '{col}'",
            )

        # children requires special handling: the golden file stores
        # numpy arrays, the function may return lists or arrays
        for i in range(len(real_result)):
            actual_children = list(real_result.iloc[i]["children"])
            expected_children = list(expected.iloc[i]["children"])
            assert actual_children == expected_children, (
                f"Row {i} children mismatch: {actual_children} != {expected_children}"
            )

    async def test_communities_with_children(self, real_result: pd.DataFrame) -> None:
        """Pin the expected number of communities that have children."""
        # __len__ probe tolerates both list and numpy-array children values.
        has_children = real_result["children"].apply(
            lambda x: hasattr(x, "__len__") and len(x) > 0
        )
        assert has_children.sum() == 24


# -------------------------------------------------------------------
# Row sanitization
# -------------------------------------------------------------------


class TestSanitizeRow:
    """Verify numpy types are converted to native Python types."""

    def test_ndarray_to_list(self) -> None:
        """np.ndarray values should become plain lists."""
        row = {"children": np.array([1, 2, 3])}
        result = _sanitize_row(row)
        assert result["children"] == [1, 2, 3]
        assert isinstance(result["children"], list)

    def test_empty_ndarray_to_empty_list(self) -> None:
        """An empty np.ndarray should become an empty list."""
        row = {"children": np.array([])}
        assert _sanitize_row(row)["children"] == []

    def test_np_integer_to_int(self) -> None:
        """np.integer values should become native int."""
        row = {"community": np.int64(42)}
        result = _sanitize_row(row)
        assert result["community"] == 42
        # Exact type check (not isinstance): np.int64 must not sneak through,
        # since it subclasses-compatibly compares equal to int values.
        assert type(result["community"]) is int

    def test_np_floating_to_float(self) -> None:
        """np.floating values should become native float."""
        row = {"weight": np.float64(3.14)}
        result = _sanitize_row(row)
        assert result["weight"] == pytest.approx(3.14)
        assert type(result["weight"]) is float

    def test_native_types_pass_through(self) -> None:
        """Native Python types should pass through unchanged."""
        row = {"id": "abc", "size": 5, "tags": ["a", "b"]}
        assert _sanitize_row(row) == row

    def test_mixed_row(self) -> None:
        """A row with a mix of numpy and native types."""
        row = {
            "community": np.int64(7),
            "children": np.array([1, 2]),
            "title": "Community 7",
            "weight": np.float64(0.5),
        }
        result = _sanitize_row(row)
        assert result == {
            "community": 7,
            "children": [1, 2],
            "title": "Community 7",
            "weight": pytest.approx(0.5),
        }
        assert type(result["community"]) is int
        assert type(result["children"]) is list
        assert type(result["weight"]) is float