Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions datajunction-clients/python/datajunction/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,18 @@ def create_parser(self):
dest="mcp",
help="Skip MCP server configuration",
)
setup_claude_parser.add_argument(
"--agents",
action="store_true",
default=True,
help="Install DJ subagent to ~/.claude/agents/ (default: True)",
)
setup_claude_parser.add_argument(
"--no-agents",
action="store_false",
dest="agents",
help="Skip subagent installation",
)

return parser

Expand Down Expand Up @@ -1356,6 +1368,7 @@ def dispatch_command(self, args, parser):
output_dir=Path(args.output),
skills=args.skills,
mcp=args.mcp,
agents=args.agents,
)
else:
parser.print_help() # pragma: no cover
Expand All @@ -1377,6 +1390,7 @@ def setup_claude(
output_dir: Path,
skills: bool = True,
mcp: bool = True,
agents: bool = True,
):
"""Configure Claude Code integration with DJ."""
import json
Expand Down Expand Up @@ -1464,23 +1478,50 @@ def setup_claude(
"[red]✗ Bundled skill not found. Please ensure datajunction is properly installed.[/red]",
)

# Install subagent if requested
if agents:
agents_dir = Path.home() / ".claude" / "agents"
agents_dir.mkdir(parents=True, exist_ok=True)
agent_file = agents_dir / "dj.md"

console.print("[bold]🤖 Installing DJ subagent[/bold]\n")

subagent_content = """\
---
name: dj
description: >
DataJunction semantic layer expert. Use proactively for any DataJunction
or DJ work — querying metrics, exploring nodes and dimensions, building
SQL, understanding lineage, and semantic layer design.
skills:
- datajunction
model: inherit
---
"""
with open(agent_file, "w") as f:
f.write(subagent_content)

console.print(f"[green]✓ Installed subagent to {agent_file}[/green]\n")

# Setup MCP if requested
if mcp:
self._setup_mcp_server(console)

# Final success message
if skills and mcp:
anything_installed = skills or mcp or agents
if anything_installed: # pragma: no branch
console.print(
"\n[bold green]✓ Claude Code integration complete[/bold green]",
)
parts = []
if skills:
parts.append("skill")
if agents:
parts.append("subagent")
if mcp:
parts.append("MCP server")
console.print(
"[dim]Skills and MCP server are now configured. Restart Claude Code to load changes.[/dim]",
)
elif skills:
console.print("\n[dim]Skills are now available in Claude Code.[/dim]")
elif mcp: # pragma: no branch
console.print(
"\n[dim]MCP server configured. Restart Claude Code to load changes.[/dim]",
f"[dim]{', '.join(parts).capitalize()} installed. Restart Claude Code to load changes.[/dim]",
)

except Exception as e: # pragma: no cover
Expand Down
148 changes: 124 additions & 24 deletions datajunction-clients/python/datajunction/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,50 +42,73 @@ async def list_tools() -> list[types.Tool]:
types.Tool(
name="search_nodes",
description=(
"Search for DataJunction nodes (metrics, dimensions, cubes, sources, transforms) "
"by name fragment or other properties. Returns a list of matching nodes with "
"their basic information including status, tags, and owners. "
"TIP: Use the 'namespace' parameter to narrow searches - namespaces are the primary "
"organizational structure in DJ (e.g., 'demo.metrics', 'common.dimensions')."
"Search for DataJunction nodes (metrics, dimensions, cubes, sources, transforms). "
"All filters are optional and combinable: name fragment, node type, namespace, tags, "
"status (valid/invalid), mode (published/draft), owner, and materialization. "
"TIP: Use 'namespace' to narrow searches to a domain. "
"Use 'statuses: [invalid]' to find broken nodes. "
"Use 'mode: draft' to see in-progress work on a branch. "
"Use 'has_materialization: true' to find cubes with materializations."
),
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search term - fragment of node name to search for (e.g., 'revenue', 'user')",
"description": "Optional: Fragment of node name to search for (e.g., 'revenue', 'user'). Can be omitted when filtering by tag.",
},
"node_type": {
"type": "string",
"enum": ["metric", "dimension", "cube", "source", "transform"],
"description": "Optional: Filter results to specific node type",
"description": "Optional: Filter results to a specific node type",
},
"namespace": {
"type": "string",
"description": (
"Optional: Filter results to specific namespace (e.g., 'demo.metrics', 'common.dimensions'). "
"HIGHLY RECOMMENDED - namespaces are the primary way to organize nodes in DJ. "
"Use this to narrow search results to a specific domain or area."
"Optional: Filter results to a specific namespace (e.g., 'demo.metrics', 'common.dimensions'). "
"HIGHLY RECOMMENDED - namespaces are the primary way to organize nodes in DJ."
),
},
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Optional: Filter to nodes tagged with ALL of these tag names (e.g., ['revenue', 'core'])",
},
"statuses": {
"type": "array",
"items": {"type": "string", "enum": ["valid", "invalid"]},
"description": "Optional: Filter by node status (e.g., ['valid'] for healthy nodes, ['invalid'] to find broken ones)",
},
"mode": {
"type": "string",
"enum": ["published", "draft"],
"description": "Optional: Filter by mode — 'published' for production nodes, 'draft' for in-progress work on a branch",
},
"owned_by": {
"type": "string",
"description": "Optional: Filter to nodes owned by this username or email",
},
"has_materialization": {
"type": "boolean",
"default": False,
"description": "Optional: If true, return only nodes that have materializations configured",
},
"limit": {
"type": "integer",
"default": 100,
"minimum": 1,
"maximum": 1000,
"description": "Maximum number of results to return (default: 100, max: 1000)",
"description": "Maximum number of results to return (default: 100)",
},
"prefer_main_branch": {
"type": "boolean",
"default": True,
"description": (
"When true and namespace is provided, automatically searches the .main branch "
"(e.g., 'finance' becomes 'finance.main'). Set to false to search all branches."
"Default: true."
),
},
},
"required": ["query"],
},
),
types.Tool(
Expand All @@ -107,22 +130,27 @@ async def list_tools() -> list[types.Tool]:
},
),
types.Tool(
name="get_common_dimensions",
name="get_common",
description=(
"Find dimensions that are available across multiple metrics. "
"Use this to determine which dimensions you can use when querying multiple metrics together. "
"Returns the list of common dimensions that work across all specified metrics."
"Bidirectional semantic compatibility lookup. "
"Pass 'metrics' to find which dimensions are shared across all those metrics (i.e. what can I slice these metrics by?). "
"Pass 'dimensions' to find which metrics can be queried by all those dimensions (i.e. what can I analyze by this dimension?). "
"Provide exactly one of metrics or dimensions."
),
inputSchema={
"type": "object",
"properties": {
"metric_names": {
"metrics": {
"type": "array",
"items": {"type": "string"},
"description": "List of metric node names — returns the dimensions common across all of them",
},
"dimensions": {
"type": "array",
"items": {"type": "string"},
"description": "List of metric node names to analyze (e.g., ['finance.revenue', 'growth.users'])",
"description": "List of dimension attribute names — returns metrics compatible with all of them",
},
},
"required": ["metric_names"],
},
),
types.Tool(
Expand Down Expand Up @@ -264,6 +292,58 @@ async def list_tools() -> list[types.Tool]:
"required": ["node_name"],
},
),
types.Tool(
name="get_query_plan",
description=(
"Get the query execution plan for a set of metrics, showing how DJ decomposes them "
"into grain groups and atomic aggregation components. "
"A grain group is a set of metrics that share a common dimensional grain and can be "
"computed together in a single SQL query. "
"Each component is an atomic aggregation (e.g., SUM(amount), COUNT(*)) that feeds "
"into the final metric formula. "
"Use this to understand multi-metric query structure, debug unexpected results, "
"validate semantic model design, or explain how a metric is computed."
),
inputSchema={
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {"type": "string"},
"description": "List of metric node names to analyze (e.g., ['finance.daily_revenue', 'growth.new_users'])",
},
"dimensions": {
"type": "array",
"items": {"type": "string"},
"description": "Optional: List of dimensions to group by — affects grain group assignment",
},
"filters": {
"type": "array",
"items": {"type": "string"},
"description": "Optional: SQL filter conditions",
},
"dialect": {
"type": "string",
"description": "Optional: Target SQL dialect (e.g., 'spark', 'trino', 'postgres')",
},
"use_materialized": {
"type": "boolean",
"default": True,
"description": "Optional: Whether to use materialized tables when available (default: true)",
},
"include_temporal_filters": {
"type": "boolean",
"default": False,
"description": "Optional: Include temporal partition filters if the metrics resolve to a cube with partitions",
},
"lookback_window": {
"type": "string",
"description": "Optional: Lookback window for temporal filters (e.g., '3 DAY', '1 WEEK'). Only used when include_temporal_filters is true.",
},
},
"required": ["metrics"],
},
),
types.Tool(
name="visualize_metrics",
description=(
Expand Down Expand Up @@ -350,9 +430,14 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:

elif name == "search_nodes":
result = await tools.search_nodes(
query=arguments["query"],
query=arguments.get("query", ""),
node_type=arguments.get("node_type"),
namespace=arguments.get("namespace"),
tags=arguments.get("tags"),
statuses=arguments.get("statuses"),
mode=arguments.get("mode"),
owned_by=arguments.get("owned_by"),
has_materialization=arguments.get("has_materialization", False),
limit=arguments.get("limit", 100),
prefer_main_branch=arguments.get("prefer_main_branch", True),
)
Expand All @@ -362,9 +447,10 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
name=arguments["name"],
)

elif name == "get_common_dimensions":
result = await tools.get_common_dimensions(
metric_names=arguments["metric_names"],
elif name == "get_common":
result = await tools.get_common(
metrics=arguments.get("metrics"),
dimensions=arguments.get("dimensions"),
)

elif name == "build_metric_sql":
Expand All @@ -387,6 +473,20 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
limit=arguments.get("limit"),
)

elif name == "get_query_plan":
result = await tools.get_query_plan(
metrics=arguments["metrics"],
dimensions=arguments.get("dimensions"),
filters=arguments.get("filters"),
dialect=arguments.get("dialect"),
use_materialized=arguments.get("use_materialized", True),
include_temporal_filters=arguments.get(
"include_temporal_filters",
False,
),
lookback_window=arguments.get("lookback_window"),
)

elif name == "get_node_lineage":
result = await tools.get_node_lineage(
node_name=arguments["node_name"],
Expand Down
Loading
Loading