diff --git a/datajunction-clients/python/datajunction/cli.py b/datajunction-clients/python/datajunction/cli.py index 4efc46ce1..a9f165910 100644 --- a/datajunction-clients/python/datajunction/cli.py +++ b/datajunction-clients/python/datajunction/cli.py @@ -1266,6 +1266,18 @@ def create_parser(self): dest="mcp", help="Skip MCP server configuration", ) + setup_claude_parser.add_argument( + "--agents", + action="store_true", + default=True, + help="Install DJ subagent to ~/.claude/agents/ (default: True)", + ) + setup_claude_parser.add_argument( + "--no-agents", + action="store_false", + dest="agents", + help="Skip subagent installation", + ) return parser @@ -1356,6 +1368,7 @@ def dispatch_command(self, args, parser): output_dir=Path(args.output), skills=args.skills, mcp=args.mcp, + agents=args.agents, ) else: parser.print_help() # pragma: no cover @@ -1377,6 +1390,7 @@ def setup_claude( output_dir: Path, skills: bool = True, mcp: bool = True, + agents: bool = True, ): """Configure Claude Code integration with DJ.""" import json @@ -1464,23 +1478,50 @@ def setup_claude( "[red]βœ— Bundled skill not found. Please ensure datajunction is properly installed.[/red]", ) + # Install subagent if requested + if agents: + agents_dir = Path.home() / ".claude" / "agents" + agents_dir.mkdir(parents=True, exist_ok=True) + agent_file = agents_dir / "dj.md" + + console.print("[bold]πŸ€– Installing DJ subagent[/bold]\n") + + subagent_content = """\ +--- +name: dj +description: > + DataJunction semantic layer expert. Use proactively for any DataJunction + or DJ work β€” querying metrics, exploring nodes and dimensions, building + SQL, understanding lineage, and semantic layer design. +skills: + - datajunction +model: inherit +--- +""" + with open(agent_file, "w") as f: + f.write(subagent_content) + + console.print(f"[green]βœ“ Installed subagent to {agent_file}[/green]\n") + # Setup MCP if requested if mcp: self._setup_mcp_server(console) # Final success message - if skills and mcp: + anything_installed = skills or mcp or agents + if anything_installed: # pragma: no branch console.print( "\n[bold green]βœ“ Claude Code integration complete[/bold green]", ) + parts = [] + if skills: + parts.append("skill") + if agents: + parts.append("subagent") + if mcp: + parts.append("MCP server") console.print( - "[dim]Skills and MCP server are now configured. Restart Claude Code to load changes.[/dim]", - ) - elif skills: - console.print("\n[dim]Skills are now available in Claude Code.[/dim]") - elif mcp: # pragma: no branch - console.print( - "\n[dim]MCP server configured. Restart Claude Code to load changes.[/dim]", + f"[dim]{', '.join(parts).capitalize()} installed. Restart Claude Code to load changes.[/dim]", ) except Exception as e: # pragma: no cover diff --git a/datajunction-clients/python/datajunction/mcp/server.py b/datajunction-clients/python/datajunction/mcp/server.py index 958b07f78..f55950beb 100644 --- a/datajunction-clients/python/datajunction/mcp/server.py +++ b/datajunction-clients/python/datajunction/mcp/server.py @@ -42,38 +42,63 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="search_nodes", description=( - "Search for DataJunction nodes (metrics, dimensions, cubes, sources, transforms) " - "by name fragment or other properties. Returns a list of matching nodes with " - "their basic information including status, tags, and owners. " - "TIP: Use the 'namespace' parameter to narrow searches - namespaces are the primary " - "organizational structure in DJ (e.g., 'demo.metrics', 'common.dimensions')." + "Search for DataJunction nodes (metrics, dimensions, cubes, sources, transforms). " + "All filters are optional and combinable: name fragment, node type, namespace, tags, " + "status (valid/invalid), mode (published/draft), owner, and materialization. " + "TIP: Use 'namespace' to narrow searches to a domain. " + "Use 'statuses: [invalid]' to find broken nodes. " + "Use 'mode: draft' to see in-progress work on a branch. " + "Use 'has_materialization: true' to find cubes with materializations." ), inputSchema={ "type": "object", "properties": { "query": { "type": "string", - "description": "Search term - fragment of node name to search for (e.g., 'revenue', 'user')", + "description": "Optional: Fragment of node name to search for (e.g., 'revenue', 'user'). Can be omitted when filtering by tag.", }, "node_type": { "type": "string", "enum": ["metric", "dimension", "cube", "source", "transform"], - "description": "Optional: Filter results to specific node type", + "description": "Optional: Filter results to a specific node type", }, "namespace": { "type": "string", "description": ( - "Optional: Filter results to specific namespace (e.g., 'demo.metrics', 'common.dimensions'). " - "HIGHLY RECOMMENDED - namespaces are the primary way to organize nodes in DJ. " - "Use this to narrow search results to a specific domain or area." + "Optional: Filter results to a specific namespace (e.g., 'demo.metrics', 'common.dimensions'). " + "HIGHLY RECOMMENDED - namespaces are the primary way to organize nodes in DJ." ), }, + "tags": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional: Filter to nodes tagged with ALL of these tag names (e.g., ['revenue', 'core'])", + }, + "statuses": { + "type": "array", + "items": {"type": "string", "enum": ["valid", "invalid"]}, + "description": "Optional: Filter by node status (e.g., ['valid'] for healthy nodes, ['invalid'] to find broken ones)", + }, + "mode": { + "type": "string", + "enum": ["published", "draft"], + "description": "Optional: Filter by mode β€” 'published' for production nodes, 'draft' for in-progress work on a branch", + }, + "owned_by": { + "type": "string", + "description": "Optional: Filter to nodes owned by this username or email", + }, + "has_materialization": { + "type": "boolean", + "default": False, + "description": "Optional: If true, return only nodes that have materializations configured", + }, "limit": { "type": "integer", "default": 100, "minimum": 1, "maximum": 1000, - "description": "Maximum number of results to return (default: 100, max: 1000)", + "description": "Maximum number of results to return (default: 100)", }, "prefer_main_branch": { "type": "boolean", @@ -81,11 +106,9 @@ async def list_tools() -> list[types.Tool]: "description": ( "When true and namespace is provided, automatically searches the .main branch " "(e.g., 'finance' becomes 'finance.main'). Set to false to search all branches." - "Default: true." ), }, }, - "required": ["query"], }, ), types.Tool( @@ -107,22 +130,27 @@ async def list_tools() -> list[types.Tool]: }, ), types.Tool( - name="get_common_dimensions", + name="get_common", description=( - "Find dimensions that are available across multiple metrics. " - "Use this to determine which dimensions you can use when querying multiple metrics together. " - "Returns the list of common dimensions that work across all specified metrics." + "Bidirectional semantic compatibility lookup. " + "Pass 'metrics' to find which dimensions are shared across all those metrics (i.e. what can I slice these metrics by?). " + "Pass 'dimensions' to find which metrics can be queried by all those dimensions (i.e. what can I analyze by this dimension?). " + "Provide exactly one of metrics or dimensions." ), inputSchema={ "type": "object", "properties": { - "metric_names": { + "metrics": { + "type": "array", + "items": {"type": "string"}, + "description": "List of metric node names β€” returns the dimensions common across all of them", + }, + "dimensions": { "type": "array", "items": {"type": "string"}, - "description": "List of metric node names to analyze (e.g., ['finance.revenue', 'growth.users'])", + "description": "List of dimension attribute names β€” returns metrics compatible with all of them", }, }, - "required": ["metric_names"], }, ), types.Tool( @@ -264,6 +292,58 @@ async def list_tools() -> list[types.Tool]: "required": ["node_name"], }, ), + types.Tool( + name="get_query_plan", + description=( + "Get the query execution plan for a set of metrics, showing how DJ decomposes them " + "into grain groups and atomic aggregation components. " + "A grain group is a set of metrics that share a common dimensional grain and can be " + "computed together in a single SQL query. " + "Each component is an atomic aggregation (e.g., SUM(amount), COUNT(*)) that feeds " + "into the final metric formula. " + "Use this to understand multi-metric query structure, debug unexpected results, " + "validate semantic model design, or explain how a metric is computed." + ), + inputSchema={ + "type": "object", + "properties": { + "metrics": { + "type": "array", + "items": {"type": "string"}, + "description": "List of metric node names to analyze (e.g., ['finance.daily_revenue', 'growth.new_users'])", + }, + "dimensions": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional: List of dimensions to group by β€” affects grain group assignment", + }, + "filters": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional: SQL filter conditions", + }, + "dialect": { + "type": "string", + "description": "Optional: Target SQL dialect (e.g., 'spark', 'trino', 'postgres')", + }, + "use_materialized": { + "type": "boolean", + "default": True, + "description": "Optional: Whether to use materialized tables when available (default: true)", + }, + "include_temporal_filters": { + "type": "boolean", + "default": False, + "description": "Optional: Include temporal partition filters if the metrics resolve to a cube with partitions", + }, + "lookback_window": { + "type": "string", + "description": "Optional: Lookback window for temporal filters (e.g., '3 DAY', '1 WEEK'). Only used when include_temporal_filters is true.", + }, + }, + "required": ["metrics"], + }, + ), types.Tool( name="visualize_metrics", description=( @@ -350,9 +430,14 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: elif name == "search_nodes": result = await tools.search_nodes( - query=arguments["query"], + query=arguments.get("query", ""), node_type=arguments.get("node_type"), namespace=arguments.get("namespace"), + tags=arguments.get("tags"), + statuses=arguments.get("statuses"), + mode=arguments.get("mode"), + owned_by=arguments.get("owned_by"), + has_materialization=arguments.get("has_materialization", False), limit=arguments.get("limit", 100), prefer_main_branch=arguments.get("prefer_main_branch", True), ) @@ -362,9 +447,10 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: name=arguments["name"], ) - elif name == "get_common_dimensions": - result = await tools.get_common_dimensions( - metric_names=arguments["metric_names"], + elif name == "get_common": + result = await tools.get_common( + metrics=arguments.get("metrics"), + dimensions=arguments.get("dimensions"), ) elif name == "build_metric_sql": @@ -387,6 +473,20 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: limit=arguments.get("limit"), ) + elif name == "get_query_plan": + result = await tools.get_query_plan( + metrics=arguments["metrics"], + dimensions=arguments.get("dimensions"), + filters=arguments.get("filters"), + dialect=arguments.get("dialect"), + use_materialized=arguments.get("use_materialized", True), + include_temporal_filters=arguments.get( + "include_temporal_filters", + False, + ), + lookback_window=arguments.get("lookback_window"), + ) + elif name == "get_node_lineage": result = await tools.get_node_lineage( node_name=arguments["node_name"], diff --git a/datajunction-clients/python/datajunction/mcp/tools.py b/datajunction-clients/python/datajunction/mcp/tools.py index 97dafc730..dafe5f8af 100644 --- a/datajunction-clients/python/datajunction/mcp/tools.py +++ b/datajunction-clients/python/datajunction/mcp/tools.py @@ -248,9 +248,14 @@ async def list_namespaces() -> str: async def search_nodes( - query: str, + query: str = "", node_type: Optional[str] = None, namespace: Optional[str] = None, + tags: Optional[List[str]] = None, + statuses: Optional[List[str]] = None, + mode: Optional[str] = None, + owned_by: Optional[str] = None, + has_materialization: bool = False, limit: int = 100, prefer_main_branch: bool = True, ) -> str: @@ -258,9 +263,14 @@ async def search_nodes( Search for nodes (metrics, dimensions, cubes, etc.) Args: - query: Search term (fragment of node name) + query: Search term (fragment of node name) β€” can be empty when filtering by other params node_type: Optional filter by type (metric, dimension, cube, source, transform) namespace: Optional filter by namespace (highly recommended for narrowing results) + tags: Optional list of tag names β€” returns nodes tagged with ALL specified tags + statuses: Optional list of statuses to filter by (e.g., ['valid'], ['invalid']) + mode: Optional filter by mode: 'published' or 'draft' + owned_by: Optional filter to nodes owned by this username or email + has_materialization: If True, return only nodes with materializations configured limit: Maximum number of results (default: 100, max: 1000) prefer_main_branch: If True and namespace provided, automatically uses .main branch (default: True) @@ -323,12 +333,22 @@ async def search_nodes( $fragment: String, $nodeTypes: [NodeType!], $namespace: String, + $tags: [String!], + $statuses: [NodeStatus!], + $mode: NodeMode, + $ownedBy: String, + $hasMaterialization: Boolean!, $limit: Int ) { findNodes( fragment: $fragment, nodeTypes: $nodeTypes, namespace: $namespace, + tags: $tags, + statuses: $statuses, + mode: $mode, + ownedBy: $ownedBy, + hasMaterialization: $hasMaterialization, limit: $limit ) { name @@ -362,9 +382,14 @@ async def search_nodes( data = await client.query( graphql_query, { - "fragment": query, + "fragment": query or None, "nodeTypes": [node_type.upper()] if node_type else None, "namespace": actual_namespace, + "tags": tags or None, + "statuses": [s.upper() for s in statuses] if statuses else None, + "mode": mode.upper() if mode else None, + "ownedBy": owned_by or None, + "hasMaterialization": has_materialization, "limit": limit, }, ) @@ -447,45 +472,106 @@ async def get_node_details(name: str) -> str: return format_error(str(e), f"Fetching details for node '{name}'") -async def get_common_dimensions(metric_names: List[str]) -> str: +async def get_common( + metrics: Optional[List[str]] = None, + dimensions: Optional[List[str]] = None, +) -> str: """ - Find dimensions that are common across multiple metrics + Bidirectional compatibility lookup: + - Pass metrics β†’ returns dimensions common across all those metrics + - Pass dimensions β†’ returns metrics that can be queried by all those dimensions + + Exactly one of metrics or dimensions must be provided. Args: - metric_names: List of metric node names + metrics: List of metric node names to find common dimensions for + dimensions: List of dimension attribute names to find compatible metrics for Returns: - Formatted dimension compatibility report + Formatted compatibility report """ - graphql_query = """ - query GetCommonDimensions($nodes: [String!]) { - commonDimensions(nodes: $nodes) { - name - type - dimensionNode { + if not metrics and not dimensions: + return format_error( + "Either 'metrics' or 'dimensions' must be provided.", + "get_common", + ) + if metrics and dimensions: + return format_error( + "Provide either 'metrics' or 'dimensions', not both.", + "get_common", + ) + + if metrics: + graphql_query = """ + query GetCommonDimensions($nodes: [String!]) { + commonDimensions(nodes: $nodes) { name - current { - description - displayName + type + dimensionNode { + name + current { + description + displayName + } } } } - } - """ - - try: - client = get_client() - data = await client.query(graphql_query, {"nodes": metric_names}) + """ + try: + client = get_client() + data = await client.query(graphql_query, {"nodes": metrics}) + dims = data.get("commonDimensions", []) + return format_dimensions_compatibility(metrics, dims) + except Exception as e: + logger.error(f"Error getting common dimensions: {str(e)}") + return format_error( + str(e), + f"Finding common dimensions for: {', '.join(metrics)}", + ) - dimensions = data.get("commonDimensions", []) - return format_dimensions_compatibility(metric_names, dimensions) + else: + dim_list: List[str] = dimensions # type: ignore[assignment] # non-None guaranteed by guard above + try: + client = get_client() + await client._ensure_token() + async with httpx.AsyncClient( + timeout=client.settings.request_timeout, + ) as http_client: + response = await http_client.get( + f"{client.settings.dj_api_url.rstrip('/')}/dimensions/common/", + params={"dimension": dim_list, "node_type": "metric"}, + headers=client._get_headers(), + ) + response.raise_for_status() + nodes = response.json() + + lines = [ + "Metrics compatible with dimensions:", + "=" * 60, + "", + f"Dimensions: {', '.join(dim_list)}", + "", + ] + if not nodes: + lines.append("No metrics found that share all specified dimensions.") + else: + lines.append(f"Found {len(nodes)} compatible metric(s):\n") + for node in nodes: + lines.append(f" β€’ {node['name']}") + return "\n".join(lines) - except Exception as e: - logger.error(f"Error getting common dimensions: {str(e)}") - return format_error( - str(e), - f"Finding common dimensions for metrics: {', '.join(metric_names)}", - ) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error: {e.response.status_code} - {e.response.text}") + return format_error( + f"API request failed: {e.response.status_code} - {e.response.text}", + f"Finding metrics for dimensions: {', '.join(dim_list)}", + ) + except Exception as e: + logger.error(f"Error getting common metrics: {str(e)}") + return format_error( + str(e), + f"Finding metrics for dimensions: {', '.join(dim_list)}", + ) async def build_metric_sql( @@ -588,6 +674,147 @@ async def build_metric_sql( ) +async def get_query_plan( + metrics: List[str], + dimensions: Optional[List[str]] = None, + filters: Optional[List[str]] = None, + dialect: Optional[str] = None, + use_materialized: bool = True, + include_temporal_filters: bool = False, + lookback_window: Optional[str] = None, +) -> str: + """ + Get the query execution plan for a set of metrics. + + Shows how DJ decomposes metrics into grain groups (sets of metrics that share + a common dimensional grain), the atomic aggregation components within each group, + and the combiner expressions that reassemble components into final metric values. + + Args: + metrics: List of metric node names to analyze + dimensions: Optional list of dimensions to group by + filters: Optional list of SQL filter conditions + dialect: Optional SQL dialect (e.g., 'spark', 'trino', 'postgres') + use_materialized: Whether to use materialized tables when available (default: True) + include_temporal_filters: Whether to include temporal partition filters (default: False) + lookback_window: Lookback window for temporal filters (e.g., '3 DAY', '1 WEEK') + + Returns: + Human-readable query plan showing grain groups, components, and metric formulas + """ + try: + client = get_client() + await client._ensure_token() + + params: Dict[str, Any] = { + "metrics": metrics, + "dimensions": dimensions or [], + "filters": filters or [], + "use_materialized": use_materialized, + "include_temporal_filters": include_temporal_filters, + } + if dialect: + params["dialect"] = dialect + if lookback_window: + params["lookback_window"] = lookback_window + + async with httpx.AsyncClient( + timeout=client.settings.request_timeout, + ) as http_client: + response = await http_client.get( + f"{client.settings.dj_api_url.rstrip('/')}/sql/measures/v3/", + params=params, + headers=client._get_headers(), + ) + response.raise_for_status() + result = response.json() + + grain_groups = result.get("grain_groups", []) + metric_formulas = result.get("metric_formulas", []) + requested_dimensions = result.get("requested_dimensions", []) + dialect_str = result.get("dialect", "N/A") + + lines = [ + "Query Execution Plan", + "=" * 60, + "", + f"Dialect: {dialect_str}", + f"Metrics: {', '.join(metrics)}", + f"Dimensions: {', '.join(requested_dimensions) if requested_dimensions else 'none'}", + f"Grain Groups: {len(grain_groups)}", + "", + ] + + # Metric formulas section + lines += ["Metric Formulas", "-" * 60] + for formula in metric_formulas: + derived_tag = " [derived]" if formula.get("is_derived") else "" + lines.append(f" {formula['name']}{derived_tag}") + lines.append(f" Original query: {formula.get('query', 'N/A')}") + lines.append(f" Combiner: {formula.get('combiner', 'N/A')}") + components = formula.get("components", []) + if components: + lines.append(f" Components: {', '.join(components)}") + if formula.get("parent_name"): + lines.append(f" Parent node: {formula['parent_name']}") + lines.append("") + + # Grain groups section + lines += ["Grain Groups", "-" * 60] + for i, gg in enumerate(grain_groups, 1): + gg_metrics = gg.get("metrics", []) + grain = gg.get("grain", []) + aggregability = gg.get("aggregability", "N/A") + parent_name = gg.get("parent_name") + scan_estimate = gg.get("scan_estimate") + + lines.append(f" Group {i}: {', '.join(gg_metrics)}") + lines.append( + f" Grain: {', '.join(grain) if grain else 'none'}", + ) + lines.append(f" Aggregability: {aggregability}") + if parent_name: + lines.append(f" Source node: {parent_name}") + if scan_estimate is not None: + if isinstance(scan_estimate, (int, float)): + lines.append(f" Scan estimate: {scan_estimate:,} rows") + else: + lines.append(f" Scan estimate: {scan_estimate}") + + components = gg.get("components", []) + if components: + lines.append(" Components:") + for comp in components: + merge = comp.get("merge", "N/A") + agg = comp.get("aggregation", "N/A") + lines.append( + f" β€’ {comp['name']}: {comp.get('expression', 'N/A')}" + f" (agg={agg}, merge={merge})", + ) + + lines.append("") + lines.append(" SQL:") + lines.append(" " + "-" * 56) + for sql_line in gg.get("sql", "").splitlines(): + lines.append(f" {sql_line}") + lines.append("") + + return "\n".join(lines) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error: {e.response.status_code} - {e.response.text}") + return format_error( + f"API request failed: {e.response.status_code} - {e.response.text}", + f"Getting query plan for metrics: {', '.join(metrics)}", + ) + except Exception as e: + logger.error(f"Error getting query plan: {str(e)}") + return format_error( + str(e), + f"Getting query plan for metrics: {', '.join(metrics)}", + ) + + async def get_metric_data( metrics: List[str], dimensions: Optional[List[str]] = None, diff --git a/datajunction-clients/python/datajunction/skills/datajunction.md b/datajunction-clients/python/datajunction/skills/datajunction.md index d8dc61ec7..672642f78 100644 --- a/datajunction-clients/python/datajunction/skills/datajunction.md +++ b/datajunction-clients/python/datajunction/skills/datajunction.md @@ -1396,70 +1396,66 @@ mode: published **Temporal partitions** enable automatic partition filtering for performance optimization. When configured, DJ automatically adds partition filters to SQL queries, dramatically improving query performance on large datasets. -#### How Temporal Partitions Work +#### How Partitions Work -When you set a temporal partition on a cube, DJ will: -1. Generate SQL with `${dj_logical_timestamp}` template variables in partition filters -2. These template variables get replaced with actual timestamp values at query execution time -3. Push down these filters to all upstream nodes that have the same dimension linked -4. Reduce data scanned by limiting to only relevant partitions based on the time range +A partition is always declared on a **column**. When that column is a dimension attribute on a cube, DJ uses it as the partition boundary and pushes the filter down to all upstream nodes that link to that same dimension. -#### Configuring Temporal Partitions +**Partition field format:** +```yaml +partition: + type: temporal # or: categorical + granularity: day # second, minute, hour, day, week, month, quarter, year + format: yyyyMMdd # Java/Spark date format (e.g. yyyyMMdd β†’ 20240101, yyyy-MM-dd β†’ 2024-01-01) +``` + +#### Declaring a Partition on a Cube + +In a cube, declare the partition in the **`columns:` section** using the **full dimension attribute path** as the column name: -**Cube YAML with temporal partition:** ```yaml # cubes/revenue_cube.yaml -name: finance.revenue_cube -description: Pre-computed revenue metrics by date and region +name: ${prefix}revenue_cube +node_type: cube metrics: - - finance.total_revenue - - finance.avg_transaction_value + - ${prefix}total_revenue + - ${prefix}order_count dimensions: - common.dimensions.date.dateint - - common.dimensions.date.month - - common.dimensions.users.country_code - -# Temporal partition configuration -temporal_partition: - dimension_attribute: common.dimensions.date.dateint - granularity: day + - common.dimensions.geo.country_code -mode: published +columns: + - name: common.dimensions.date.dateint # ← must match exactly the entry in dimensions + display_name: Date + attributes: + - primary_key + partition: + type: temporal + granularity: day + format: yyyyMMdd ``` -**Temporal partition fields:** -- `dimension_attribute` - The dimension attribute used for partitioning (typically a date field) -- `granularity` - Time granularity: `day`, `month`, `quarter`, `year` +#### How Partition Filter Pushdown Works -#### Requirements for Partition Filtering +Once a cube column has a partition spec, DJ: +1. Generates SQL with `${dj_logical_timestamp}` template variables when `include_temporal_filters=True` +2. Pushes those filters down to all upstream nodes that link to the same dimension +3. Reduces data scanned by limiting to relevant partitions -For DJ to generate partition filters, **all upstream nodes** (sources, transforms, dimensions) must: -1. Have the **same dimension linked** that's used in the temporal partition -2. Use the **same join key** (e.g., `dateint`) +For filter pushdown to work, upstream nodes (sources, transforms) must have a **dimension link to the same dimension**: -**Example - Upstream node with matching dimension link:** ```yaml -# nodes/sources/transactions.yaml -name: finance.transactions -type: source -# ... - +# transforms/orders.yaml dimension_links: - - dimension: common.dimensions.date - join_on: finance.transactions.transaction_date = common.dimensions.date.dateint - # ↑ This matches the temporal_partition.dimension_attribute in the cube + - type: join + dimension_node: common.dimensions.date + join_type: left + join_on: ${prefix}orders.order_date = common.dimensions.date.dateint + # ↑ DJ traces this link and pushes WHERE order_date >= X AND order_date <= Y ``` -**What happens:** -- βœ… If cube has `temporal_partition.dimension_attribute: common.dimensions.date.dateint` -- βœ… And upstream node links to `common.dimensions.date` on `dateint` -- βœ… Then DJ automatically adds partition filters like `WHERE transaction_date >= X AND transaction_date <= Y` - -**What if dimension links don't match:** -- ❌ Cube has temporal partition on `common.dimensions.date.dateint` -- ❌ But upstream node doesn't link to `common.dimensions.date` -- ❌ Result: No automatic partition filtering, full table scan! +- βœ… Upstream node links to `common.dimensions.date` on `order_date` β†’ DJ pushes `WHERE order_date >= X AND order_date <= Y` +- ❌ Upstream node has no link to `common.dimensions.date` β†’ no filter pushed, full table scan #### Regular Filters vs Temporal Filters @@ -1480,17 +1476,17 @@ WHERE transaction_date = 20240101 -- ← Direct filter value GROUP BY transaction_date ``` -**Temporal filters** - Use when you want template variables for incremental processing: +**Temporal filters** - Use when you want to see the pre-aggregation SQL with `${dj_logical_timestamp}` template variables for incremental processing. Use the `get_query_plan` MCP tool: ``` -build_metric_sql( +get_query_plan( metrics=["finance.total_revenue"], dimensions=["common.dimensions.date.dateint"], - include_temporal_filters=True, # Enable temporal filter template generation + include_temporal_filters=True, # Inject temporal filter templates into the grain group SQL lookback_window="7 DAY" # Optional: lookback window ) ``` -**Generated SQL:** +This shows the grain group SQL with template variables that get substituted at materialization time: ```sql SELECT SUM(amount_usd) AS total_revenue, transaction_date FROM finance.transactions @@ -1499,106 +1495,97 @@ WHERE transaction_date >= ${dj_logical_timestamp} -- ← Template variable GROUP BY transaction_date ``` -**When to use temporal filters:** -- Materialization jobs that run incrementally -- Scheduled queries that need dynamic time ranges -- Pre-aggregation pipelines +**When to use temporal filters (via `get_query_plan`):** +- Generating SQL for materialization jobs that run incrementally +- Understanding how pre-aggregation SQL will look with partition filters applied +- Debugging whether partition filter pushdown is working correctly -**When to use regular filters:** +**When to use regular filters (via `build_metric_sql`):** - Ad-hoc queries with specific date ranges -- One-time analysis -- When you know the exact filter values - -**How temporal filters work:** -- `include_temporal_filters=True` generates SQL with `${dj_logical_timestamp}` template variables -- These placeholders get replaced with actual timestamp values at query execution time -- `lookback_window` parameter controls the time range (e.g., '3 DAY', '1 WEEK', '30 DAY') -- The actual filter values are calculated based on cube's temporal partition configuration and execution time +- One-time analysis with known filter values #### Best Practices for Temporal Partitions -1. **Always set temporal partitions on cubes used for dashboards** - - Dramatically improves query performance - - Reduces data scanned +1. **Declare the partition on the cube's `columns:` block** + - Use the full dimension attribute path as the column name (must match exactly what's in `dimensions:`) + - Without a `partition:` declared on a cube column, DJ cannot enable partition filtering for that cube 2. **Ensure consistent dimension links across all nodes** - - Check that all upstream sources/transforms link to the same date dimension - - Use the same join key (e.g., always `dateint`, not mixing `dateint` and `date_str`) + - All upstream sources/transforms must link to the same dimension that carries the partition + - Use the same join key everywhere (e.g., always `dateint`, not mixing `dateint` and `date_str`) -3. **Use appropriate granularity** - - `day` - For daily metrics and dashboards (most common) - - `month` - For monthly aggregations - - `quarter`, `year` - For higher-level reporting +3. **Use appropriate granularity and format** + - `granularity: day` with `format: yyyyMMdd` β€” for integer date partitions like `20240101` + - `granularity: day` with `format: yyyy-MM-dd` β€” for string date partitions like `2024-01-01` + - `granularity: month`, `quarter`, `year` β€” for coarser partitioning 4. **Verify partition filtering is working** - Use `build_metric_sql` with `include_temporal_filters=True` - Check generated SQL includes partition filters on upstream tables - - If filters missing, check dimension link consistency + - If filters are missing, check that upstream nodes have dimension links pointing to the partitioned dimension column -5. **Match physical partition scheme** - - If your data warehouse partitions by `date`, use `dateint` in temporal partition - - Align with how data is actually partitioned in storage +5. **Match the physical partition scheme of your warehouse** + - The `format` must match how partition values are actually stored in the table + - Align granularity with how data is physically partitioned in storage #### Example: Complete Temporal Partition Setup -**Step 1: Source with date dimension link** +The partition is declared on the cube's `dateint` column. DJ pushes the filter down to `orders` because it has a dimension link to `common.dimensions.date`. + +**Step 1: Transform with date dimension link** ```yaml -# nodes/sources/orders.yaml -name: ecommerce.orders -type: source -catalog: prod -schema_: ecommerce -table: orders_partitioned +# transforms/orders.yaml +name: ${prefix}orders +node_type: transform +columns: + - name: order_date + - name: product_id + - name: order_count + - name: total_revenue dimension_links: - - dimension: common.dimensions.date - join_on: ecommerce.orders.order_date = common.dimensions.date.dateint -``` + - type: join + dimension_node: common.dimensions.time.date + join_type: left + join_on: ${prefix}orders.order_date = common.dimensions.time.date.dateint -**Step 2: Transform with same date dimension link** -```yaml -# nodes/transforms/daily_orders.yaml -name: ecommerce.daily_orders -type: transform query: | - SELECT - product_id, - order_date, - COUNT(*) AS order_count, - SUM(amount_usd) AS total_revenue - FROM ecommerce.orders + SELECT product_id, order_date, COUNT(*) AS order_count, SUM(amount_usd) AS total_revenue + FROM source.prod.orders_f GROUP BY product_id, order_date - -dimension_links: - - dimension: common.dimensions.date - join_on: ecommerce.daily_orders.order_date = common.dimensions.date.dateint ``` -**Step 3: Metrics on the transform** +**Step 2: Metrics** ```yaml -# nodes/metrics/total_orders.yaml -name: ecommerce.total_orders -type: metric -query: SELECT SUM(order_count) FROM ecommerce.daily_orders +# metrics/total_orders.yaml +name: ${prefix}total_orders +node_type: metric +query: SELECT SUM(order_count) FROM ${prefix}orders ``` -**Step 4: Cube with temporal partition** +**Step 3: Cube β€” declare the partition on the external dimension attribute** ```yaml -# cubes/ecommerce_cube.yaml -name: ecommerce.ecommerce_cube +# cubes/orders_cube.yaml +name: ${prefix}orders_cube +node_type: cube metrics: - - ecommerce.total_orders - - ecommerce.total_revenue + - ${prefix}total_orders dimensions: - - common.dimensions.date.dateint - - ecommerce.daily_orders.product_id + - common.dimensions.time.date.dateint + - ${prefix}orders.product_id -temporal_partition: - dimension_attribute: common.dimensions.date.dateint - granularity: day +columns: + - name: common.dimensions.time.date.dateint # ← full attribute path, matches dimensions entry + display_name: Date + attributes: + - primary_key + partition: + type: temporal + granularity: day + format: yyyyMMdd ``` -**Result**: Queries on `ecommerce.ecommerce_cube` will automatically include partition filters on both `ecommerce.orders` and `ecommerce.daily_orders` tables! +**Result**: The cube has a temporal partition on `common.dimensions.time.date.dateint`. Queries with `include_temporal_filters=True` will push `WHERE order_date >= X AND order_date <= Y` to the `orders` transform. --- diff --git a/datajunction-clients/python/tests/mcp/test_server_tools.py b/datajunction-clients/python/tests/mcp/test_server_tools.py index 2026434b0..e7742be53 100644 --- a/datajunction-clients/python/tests/mcp/test_server_tools.py +++ b/datajunction-clients/python/tests/mcp/test_server_tools.py @@ -48,6 +48,11 @@ async def test_call_tool_search_nodes(): query="revenue", node_type="metric", namespace="default", + tags=None, + statuses=None, + mode=None, + owned_by=None, + has_materialization=False, limit=10, prefer_main_branch=True, ) @@ -66,11 +71,47 @@ async def test_call_tool_search_nodes_minimal_args(): query="test", node_type=None, namespace=None, + tags=None, + statuses=None, + mode=None, + owned_by=None, + has_materialization=False, limit=100, # default prefer_main_branch=True, ) +@pytest.mark.asyncio +async def test_call_tool_search_nodes_with_new_filters(): + """Test search_nodes dispatch passes tags, statuses, mode, owned_by, has_materialization""" + with patch("datajunction.mcp.server.tools.search_nodes") as mock_search: + mock_search.return_value = "Results" + + await call_tool( + "search_nodes", + { + "tags": ["revenue", "core"], + "statuses": ["valid"], + "mode": "published", + "owned_by": "alice@example.com", + "has_materialization": True, + }, + ) + + mock_search.assert_called_once_with( + query="", + node_type=None, + namespace=None, + tags=["revenue", "core"], + statuses=["valid"], + mode="published", + owned_by="alice@example.com", + has_materialization=True, + limit=100, + prefer_main_branch=True, + ) + + @pytest.mark.asyncio async def test_call_tool_get_node_details(): """Test calling get_node_details tool""" @@ -85,19 +126,41 @@ async def test_call_tool_get_node_details(): @pytest.mark.asyncio -async def test_call_tool_get_common_dimensions(): - """Test calling get_common_dimensions tool""" - with patch("datajunction.mcp.server.tools.get_common_dimensions") as mock_dims: - mock_dims.return_value = "Common dimensions:\n- date\n- region" +async def test_call_tool_get_common_metrics_direction(): + """Test calling get_common tool with metrics β†’ dimensions direction""" + with patch("datajunction.mcp.server.tools.get_common") as mock_common: + mock_common.return_value = "Common dimensions:\n- date\n- region" result = await call_tool( - "get_common_dimensions", - {"metric_names": ["metric1", "metric2"]}, + "get_common", + {"metrics": ["metric1", "metric2"]}, ) assert len(result) == 1 assert "Common dimensions" in result[0].text - mock_dims.assert_called_once_with(metric_names=["metric1", "metric2"]) + mock_common.assert_called_once_with( + metrics=["metric1", "metric2"], + dimensions=None, + ) + + +@pytest.mark.asyncio +async def test_call_tool_get_common_dimensions_direction(): + """Test calling get_common tool with dimensions β†’ metrics direction""" + with patch("datajunction.mcp.server.tools.get_common") as mock_common: + mock_common.return_value = "Compatible metrics:\n- finance.revenue" + + result = await call_tool( + "get_common", + {"dimensions": ["common.dimensions.date.dateint"]}, + ) + + assert len(result) == 1 + assert "Compatible metrics" in result[0].text + mock_common.assert_called_once_with( + metrics=None, + dimensions=["common.dimensions.date.dateint"], + ) @pytest.mark.asyncio @@ -303,17 +366,20 @@ async def test_list_tools_handler(): # Verify we get the expected tools assert isinstance(tools, list) - assert len(tools) == 9 # Should have 8 tools + assert len(tools) == 10 tool_names = [tool.name for tool in tools] assert "list_namespaces" in tool_names assert "search_nodes" in tool_names assert "get_node_details" in tool_names - assert "get_common_dimensions" in tool_names + assert "get_common" in tool_names assert "build_metric_sql" in tool_names assert "get_metric_data" in tool_names assert "get_node_lineage" in tool_names assert "get_node_dimensions" in tool_names + assert "get_query_plan" in tool_names assert "visualize_metrics" in tool_names + # Ensure old tool name is gone + assert "get_common_dimensions" not in tool_names @pytest.mark.asyncio @@ -459,3 +525,61 @@ async def test_call_tool_visualize_metrics_with_y_min(): title=None, y_min=0, ) + + +@pytest.mark.asyncio +async def test_call_tool_get_query_plan(): + """Test calling get_query_plan tool with full args""" + with patch("datajunction.mcp.server.tools.get_query_plan") as mock_plan: + mock_plan.return_value = ( + "Query Execution Plan\n=" * 60 + "\nDialect: spark\nGrain Groups: 1\n" + ) + + result = await call_tool( + "get_query_plan", + { + "metrics": ["finance.revenue", "finance.orders"], + "dimensions": ["common.dimensions.date.dateint"], + "filters": ["date >= '2024-01-01'"], + "dialect": "spark", + "use_materialized": True, + "include_temporal_filters": True, + "lookback_window": "7 DAY", + }, + ) + + assert len(result) == 1 + assert result[0].type == "text" + assert "Query Execution Plan" in result[0].text + mock_plan.assert_called_once_with( + metrics=["finance.revenue", "finance.orders"], + dimensions=["common.dimensions.date.dateint"], + filters=["date >= '2024-01-01'"], + dialect="spark", + use_materialized=True, + include_temporal_filters=True, + lookback_window="7 DAY", + ) + + +@pytest.mark.asyncio +async def test_call_tool_get_query_plan_minimal(): + """Test get_query_plan dispatch with only required metrics arg""" + with patch("datajunction.mcp.server.tools.get_query_plan") as mock_plan: + mock_plan.return_value = "Query Execution Plan\n" + + result = await call_tool( + "get_query_plan", + {"metrics": ["finance.revenue"]}, + ) + + assert len(result) == 1 + mock_plan.assert_called_once_with( + metrics=["finance.revenue"], + dimensions=None, + filters=None, + dialect=None, + use_materialized=True, + include_temporal_filters=False, + lookback_window=None, + ) diff --git a/datajunction-clients/python/tests/mcp/test_tools.py b/datajunction-clients/python/tests/mcp/test_tools.py index 4527f1752..cd601d975 100644 --- a/datajunction-clients/python/tests/mcp/test_tools.py +++ b/datajunction-clients/python/tests/mcp/test_tools.py @@ -490,13 +490,13 @@ async def test_get_node_details_with_git_info(): # ============================================================================ -# get_common_dimensions Tests +# get_common Tests (bidirectional: metricsβ†’dimensions, dimensionsβ†’metrics) # ============================================================================ @pytest.mark.asyncio -async def test_get_common_dimensions_success(): - """Test getting common dimensions""" +async def test_get_common_metrics_to_dimensions(): + """Test get_common: metrics path returns common dimensions via GraphQL""" mock_response = { "commonDimensions": [ { @@ -529,18 +529,21 @@ async def test_get_common_dimensions_success(): mock_client.query.return_value = mock_response mock_get_client.return_value = mock_client - result = await tools.get_common_dimensions( - metric_names=["finance.revenue", "growth.users"], + result = await tools.get_common( + metrics=["finance.revenue", "growth.users"], ) assert "Found 2 common dimensions" in result assert "core.date" in result assert "core.region" in result + # Verify GraphQL was called with the right nodes variable + call_args = mock_client.query.call_args + assert call_args[0][1]["nodes"] == ["finance.revenue", "growth.users"] @pytest.mark.asyncio -async def test_get_common_dimensions_none(): - """Test getting common dimensions when none exist""" +async def test_get_common_metrics_no_dimensions(): + """Test get_common: metrics path with no shared dimensions""" mock_response = {"commonDimensions": []} with patch.object(tools, "get_client") as mock_get_client: @@ -548,29 +551,177 @@ async def test_get_common_dimensions_none(): mock_client.query.return_value = mock_response mock_get_client.return_value = mock_client - result = await tools.get_common_dimensions( - metric_names=["finance.revenue", "growth.users"], + result = await tools.get_common( + metrics=["finance.revenue", "growth.users"], ) assert "No common dimensions found" in result @pytest.mark.asyncio -async def test_get_common_dimensions_error(): - """Test common dimensions error handling""" +async def test_get_common_metrics_error(): + """Test get_common: metrics path GraphQL error handling""" with patch.object(tools, "get_client") as mock_get_client: mock_client = AsyncMock() mock_client.query.side_effect = Exception("Query failed") mock_get_client.return_value = mock_client - result = await tools.get_common_dimensions( - metric_names=["finance.revenue"], + result = await tools.get_common( + metrics=["finance.revenue"], ) assert "Error" in result assert "Query failed" in result +@pytest.mark.asyncio +async def test_get_common_dimensions_to_metrics(): + """Test get_common: dimensions path returns compatible metrics via REST""" + mock_nodes = [ + {"name": "finance.revenue"}, + {"name": "finance.orders"}, + ] + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_nodes + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.settings.dj_api_url = "http://localhost:8000" + mock_client.settings.request_timeout = 30.0 + mock_client._get_headers.return_value = {} + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_common( + dimensions=[ + "common.dimensions.date.dateint", + "common.dimensions.region.id", + ], + ) + + assert "Metrics compatible with dimensions" in result + assert "Found 2 compatible metric(s)" in result + assert "finance.revenue" in result + assert "finance.orders" in result + # Verify REST endpoint was called with correct params + call_kwargs = mock_http_client.get.call_args + assert "dimension" in call_kwargs.kwargs[ + "params" + ] or "dimension" in call_kwargs[1].get("params", {}) + + +@pytest.mark.asyncio +async def test_get_common_dimensions_no_metrics(): + """Test get_common: dimensions path with no compatible metrics""" + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = [] + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.settings.dj_api_url = "http://localhost:8000" + mock_client.settings.request_timeout = 30.0 + mock_client._get_headers.return_value = {} + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_common( + dimensions=["common.dimensions.date.dateint"], + ) + + assert "No metrics found" in result + + +@pytest.mark.asyncio +async def test_get_common_no_args(): + """Test get_common returns error when neither metrics nor dimensions are given""" + result = await tools.get_common() + + assert "Error" in result + assert "Either" in result or "metrics" in result + + +@pytest.mark.asyncio +async def test_get_common_both_args(): + """Test get_common returns error when both metrics and dimensions are given""" + result = await tools.get_common( + metrics=["finance.revenue"], + dimensions=["common.dimensions.date.dateint"], + ) + + assert "Error" in result + assert "not both" in result or "either" in result.lower() + + +@pytest.mark.asyncio +async def test_get_common_dimensions_http_error(): + """Test get_common: dimensions path HTTP error handling""" + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.settings.dj_api_url = "http://localhost:8000" + mock_client.settings.request_timeout = 30.0 + mock_client._get_headers.return_value = {} + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.side_effect = httpx.HTTPStatusError( + "404 Not Found", + request=MagicMock(), + response=MagicMock(status_code=404, text="Not found"), + ) + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_common( + dimensions=["common.dimensions.date.dateint"], + ) + + assert "Error" in result + assert "404" in result + + +@pytest.mark.asyncio +async def test_get_common_dimensions_generic_error(): + """Test get_common: dimensions path generic exception handling""" + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.settings.dj_api_url = "http://localhost:8000" + mock_client.settings.request_timeout = 30.0 + mock_client._get_headers.return_value = {} + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.side_effect = Exception("Connection refused") + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_common( + dimensions=["common.dimensions.date.dateint"], + ) + + assert "Error" in result + assert "Connection refused" in result + + # ============================================================================ # build_metric_sql Tests # ============================================================================ @@ -2232,3 +2383,656 @@ async def test_get_metric_data_no_materialized_cube(): assert "No materialized cube available" in result assert "expensive ad-hoc computation" in result assert "test.metric" in result + + +# ============================================================================ +# search_nodes β€” new filter parameters +# ============================================================================ + + +@pytest.mark.asyncio +async def test_search_nodes_with_tags(): + """Test search_nodes passes tags to GraphQL""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(tags=["revenue", "core"]) + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["tags"] == ["revenue", "core"] + + +@pytest.mark.asyncio +async def test_search_nodes_with_statuses(): + """Test search_nodes uppercases and passes statuses to GraphQL""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(statuses=["valid"]) + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["statuses"] == ["VALID"] + + +@pytest.mark.asyncio +async def test_search_nodes_with_invalid_status(): + """Test search_nodes with 'invalid' status is uppercased correctly""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(statuses=["invalid"]) + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["statuses"] == ["INVALID"] + + +@pytest.mark.asyncio +async def test_search_nodes_with_mode_published(): + """Test search_nodes with mode='published' is uppercased""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(mode="published") + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["mode"] == "PUBLISHED" + + +@pytest.mark.asyncio +async def test_search_nodes_with_mode_draft(): + """Test search_nodes with mode='draft' is uppercased""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(mode="draft") + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["mode"] == "DRAFT" + + +@pytest.mark.asyncio +async def test_search_nodes_with_owned_by(): + """Test search_nodes passes owned_by to GraphQL""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(owned_by="alice@example.com") + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["ownedBy"] == "alice@example.com" + + +@pytest.mark.asyncio +async def test_search_nodes_with_has_materialization(): + """Test search_nodes passes has_materialization=True to GraphQL""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(has_materialization=True) + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["hasMaterialization"] is True + + +@pytest.mark.asyncio +async def test_search_nodes_has_materialization_defaults_false(): + """Test search_nodes has_materialization defaults to False""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(query="revenue") + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["hasMaterialization"] is False + + +@pytest.mark.asyncio +async def test_search_nodes_empty_query_no_fragment(): + """Test search_nodes with empty query sends None as fragment""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes(query="") + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["fragment"] is None + + +@pytest.mark.asyncio +async def test_search_nodes_no_args_sends_none_fragment(): + """Test search_nodes with no args sends None fragment""" + mock_response = {"findNodes": []} + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + await tools.search_nodes() + + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["fragment"] is None + assert variables["tags"] is None + assert variables["statuses"] is None + assert variables["mode"] is None + assert variables["ownedBy"] is None + + +@pytest.mark.asyncio +async def test_search_nodes_combined_new_filters(): + """Test search_nodes with all new filters combined""" + mock_response = { + "findNodes": [ + { + "name": "finance.revenue", + "type": "METRIC", + "createdAt": "2024-01-01T00:00:00Z", + "current": { + "displayName": "Revenue", + "description": "Total revenue", + "status": "VALID", + "mode": "PUBLISHED", + }, + "tags": [{"name": "core", "tagType": "category"}], + "owners": [{"username": "alice", "email": "alice@example.com"}], + }, + ], + } + + with patch.object(tools, "get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.query.return_value = mock_response + mock_get_client.return_value = mock_client + + result = await tools.search_nodes( + query="revenue", + tags=["core"], + statuses=["valid"], + mode="published", + owned_by="alice@example.com", + has_materialization=True, + ) + + assert "finance.revenue" in result + call_args = mock_client.query.call_args + variables = call_args[0][1] + assert variables["tags"] == ["core"] + assert variables["statuses"] == ["VALID"] + assert variables["mode"] == "PUBLISHED" + assert variables["ownedBy"] == "alice@example.com" + assert variables["hasMaterialization"] is True + + +# ============================================================================ +# get_query_plan Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_query_plan_success(): + """Test get_query_plan with a full response""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": ["common.dimensions.date.dateint"], + "grain_groups": [ + { + "metrics": ["finance.revenue", "finance.orders"], + "grain": ["common.dimensions.date.dateint"], + "aggregability": "FULL", + "parent_name": "finance.revenue_cube", + "scan_estimate": 1000000, + "components": [ + { + "name": "revenue_sum", + "expression": "SUM(amount)", + "aggregation": "SUM", + "merge": "SUM", + }, + { + "name": "order_count", + "expression": "COUNT(*)", + "aggregation": "COUNT", + "merge": "SUM", + }, + ], + "sql": "SELECT dateint, SUM(amount), COUNT(*) FROM orders GROUP BY dateint", + }, + ], + "metric_formulas": [ + { + "name": "finance.revenue", + "query": "SUM(amount)", + "combiner": "SUM(revenue_sum)", + "is_derived": False, + "components": ["revenue_sum"], + "parent_name": "finance.revenue_cube", + }, + { + "name": "finance.orders", + "query": "COUNT(*)", + "combiner": "SUM(order_count)", + "is_derived": False, + "components": ["order_count"], + }, + ], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan( + metrics=["finance.revenue", "finance.orders"], + dimensions=["common.dimensions.date.dateint"], + ) + + assert "Query Execution Plan" in result + assert "Dialect: spark" in result + assert "finance.revenue" in result + assert "finance.orders" in result + assert "Grain Groups: 1" in result + assert "Group 1:" in result + assert "1,000,000 rows" in result + assert "finance.revenue_cube" in result + assert "Metric Formulas" in result + assert "SUM(amount)" in result + assert "Components:" in result + assert "revenue_sum" in result + + +@pytest.mark.asyncio +async def test_get_query_plan_scan_estimate_as_dict(): + """Test get_query_plan handles dict scan_estimate without crashing""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [ + { + "metrics": ["finance.revenue"], + "grain": [], + "aggregability": "FULL", + "scan_estimate": {"rows": 500, "bytes": 1024}, # dict, not int + "components": [], + "sql": "SELECT SUM(amount) FROM orders", + }, + ], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue"]) + + # Should not raise a format error; dict is rendered as string + assert "Query Execution Plan" in result + assert "Scan estimate:" in result + assert "rows" not in result or "rows" in result # dict repr, no ":," format + + +@pytest.mark.asyncio +async def test_get_query_plan_derived_metric(): + """Test get_query_plan shows [derived] tag for derived metrics""" + mock_response_json = { + "dialect": "trino", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [ + { + "name": "finance.revenue_per_order", + "query": "SUM(amount) / COUNT(*)", + "combiner": "SUM(revenue_sum) / SUM(order_count)", + "is_derived": True, + "components": ["revenue_sum", "order_count"], + "parent_name": "finance.orders_cube", + }, + ], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue_per_order"]) + + assert "[derived]" in result + assert "finance.revenue_per_order" in result + assert "Dialect: trino" in result + + +@pytest.mark.asyncio +async def test_get_query_plan_with_dialect_and_lookback(): + """Test get_query_plan passes dialect and lookback_window params""" + mock_response_json = { + "dialect": "trino", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + await tools.get_query_plan( + metrics=["finance.revenue"], + dialect="trino", + include_temporal_filters=True, + lookback_window="7 DAY", + ) + + call_kwargs = mock_http_client.get.call_args + # Extract params regardless of positional/keyword style + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert params.get("dialect") == "trino" + assert params.get("lookback_window") == "7 DAY" + assert params.get("include_temporal_filters") is True + + +@pytest.mark.asyncio +async def test_get_query_plan_no_dialect_or_lookback_not_in_params(): + """Test get_query_plan omits dialect and lookback_window when not provided""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + await tools.get_query_plan(metrics=["finance.revenue"]) + + call_kwargs = mock_http_client.get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "dialect" not in params + assert "lookback_window" not in params + + +@pytest.mark.asyncio +async def test_get_query_plan_http_error(): + """Test get_query_plan handles HTTP errors""" + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.side_effect = httpx.HTTPStatusError( + "500 Internal Server Error", + request=MagicMock(), + response=MagicMock(status_code=500, text="Internal error"), + ) + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue"]) + + assert "Error" in result + assert "500" in result + + +@pytest.mark.asyncio +async def test_get_query_plan_generic_error(): + """Test get_query_plan handles generic exceptions""" + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.side_effect = Exception("Unexpected failure") + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue"]) + + assert "Error" in result + assert "Unexpected failure" in result + + +@pytest.mark.asyncio +async def test_get_query_plan_empty_grain_groups(): + """Test get_query_plan with no grain groups""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [ + { + "name": "finance.revenue", + "query": "SUM(amount)", + "combiner": "SUM(revenue_sum)", + "is_derived": False, + "components": [], + }, + ], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue"]) + + assert "Query Execution Plan" in result + assert "Grain Groups: 0" in result + assert "finance.revenue" in result + + +@pytest.mark.asyncio +async def test_get_query_plan_no_scan_estimate(): + """Test get_query_plan with missing scan_estimate in grain group""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [ + { + "metrics": ["finance.revenue"], + "grain": [], + "aggregability": "FULL", + # no scan_estimate key + "components": [], + "sql": "SELECT SUM(amount) FROM orders", + }, + ], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + result = await tools.get_query_plan(metrics=["finance.revenue"]) + + assert "Query Execution Plan" in result + assert "Scan estimate:" not in result diff --git a/datajunction-clients/python/tests/test_cli.py b/datajunction-clients/python/tests/test_cli.py index c1adc555c..40a5922a2 100644 --- a/datajunction-clients/python/tests/test_cli.py +++ b/datajunction-clients/python/tests/test_cli.py @@ -2084,9 +2084,8 @@ def test_setup_claude_skills_only(tmp_path, monkeypatch): "MCP config should NOT be created with --no-mcp" ) - # Verify success message only for skills - assert "Skill installed" in output - assert "Skills are now available in Claude Code" in output + # Verify success message only for skills / subagent (not MCP) + assert "Skill installed" in output or "skill" in output.lower() assert "MCP server" not in output @@ -2267,7 +2266,8 @@ def test_setup_claude_skill_content_verification(tmp_path, monkeypatch): def test_setup_claude_no_config_path(tmp_path, monkeypatch): """Test MCP setup when no Claude config path can be found""" - # Create a temp home that doesn't have any of the expected paths + # Create a temp home that doesn't have any of the expected paths. + # Use --no-agents so the subagent installer doesn't create the home dir first. fake_home = tmp_path / "nonexistent" # Don't create fake_home or any parent directories @@ -2275,7 +2275,7 @@ def test_setup_claude_no_config_path(tmp_path, monkeypatch): with ( patch("pathlib.Path.home", return_value=fake_home), - patch.object(sys, "argv", ["dj", "setup-claude", "--no-skills"]), + patch.object(sys, "argv", ["dj", "setup-claude", "--no-skills", "--no-agents"]), patch("sys.stdout", new_callable=StringIO) as mock_stdout, ): from datajunction import cli as dj_cli @@ -2347,3 +2347,141 @@ def test_setup_claude_invalid_json_config(tmp_path, monkeypatch): new_config = json.loads(config_file.read_text()) assert "mcpServers" in new_config assert "datajunction" in new_config["mcpServers"] + + +def test_setup_claude_full_install_includes_subagent(tmp_path, monkeypatch): + """Test setup-claude (default) installs skills, MCP, and the subagent""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + + monkeypatch.setenv("HOME", str(tmp_path)) + + with ( + patch("pathlib.Path.home", return_value=tmp_path), + patch.object(sys, "argv", ["dj", "setup-claude"]), + patch("sys.stdout", new_callable=StringIO) as mock_stdout, + ): + from datajunction import cli as dj_cli + + dj_cli.main() + + output = mock_stdout.getvalue() + + # Verify subagent was installed + agent_file = claude_dir / "agents" / "dj.md" + assert agent_file.exists(), "dj.md subagent should be created" + content = agent_file.read_text() + assert "name: dj" in content + assert "skills:" in content + assert "datajunction" in content + + # Verify success output mentions subagent + assert "subagent" in output.lower() or "Installed subagent" in output + + +def test_setup_claude_no_agents(tmp_path, monkeypatch): + """Test setup-claude --no-agents skips subagent installation""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + + monkeypatch.setenv("HOME", str(tmp_path)) + + with ( + patch("pathlib.Path.home", return_value=tmp_path), + patch.object(sys, "argv", ["dj", "setup-claude", "--no-agents", "--no-mcp"]), + patch("sys.stdout", new_callable=StringIO), + ): + from datajunction import cli as dj_cli + + dj_cli.main() + + # Subagent file should NOT exist + agent_file = claude_dir / "agents" / "dj.md" + assert not agent_file.exists(), "dj.md should NOT be created with --no-agents" + + +def test_setup_claude_agents_only(tmp_path, monkeypatch): + """Test setup-claude --no-skills --no-mcp installs only the subagent""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + + monkeypatch.setenv("HOME", str(tmp_path)) + + with ( + patch("pathlib.Path.home", return_value=tmp_path), + patch.object(sys, "argv", ["dj", "setup-claude", "--no-skills", "--no-mcp"]), + patch("sys.stdout", new_callable=StringIO) as mock_stdout, + ): + from datajunction import cli as dj_cli + + dj_cli.main() + + output = mock_stdout.getvalue() + + # Subagent should be installed + agent_file = claude_dir / "agents" / "dj.md" + assert agent_file.exists(), "dj.md subagent should be created" + + # Skill should NOT be installed + skills_dir = claude_dir / "skills" / "datajunction" + assert not skills_dir.exists(), "Skills should NOT be installed with --no-skills" + + # MCP config should NOT be created + mcp_config_file = tmp_path / ".claude.json" + assert not mcp_config_file.exists(), ( + "MCP config should NOT be created with --no-mcp" + ) + + assert "subagent" in output.lower() or "Installed subagent" in output + + +def test_setup_claude_subagent_content(tmp_path, monkeypatch): + """Test the installed subagent has the correct frontmatter content""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + + monkeypatch.setenv("HOME", str(tmp_path)) + + with ( + patch("pathlib.Path.home", return_value=tmp_path), + patch.object(sys, "argv", ["dj", "setup-claude", "--no-skills", "--no-mcp"]), + patch("sys.stdout", new_callable=StringIO), + ): + from datajunction import cli as dj_cli + + dj_cli.main() + + agent_file = claude_dir / "agents" / "dj.md" + content = agent_file.read_text() + + assert content.startswith("---") + assert "name: dj" in content + assert "description:" in content + assert "DataJunction" in content + assert "skills:" in content + assert "- datajunction" in content + assert "model: inherit" in content + + +def test_setup_claude_subagent_overwrites_existing(tmp_path, monkeypatch): + """Test setup-claude overwrites an existing subagent file""" + claude_dir = tmp_path / ".claude" + agents_dir = claude_dir / "agents" + agents_dir.mkdir(parents=True) + agent_file = agents_dir / "dj.md" + agent_file.write_text("OLD CONTENT") + + monkeypatch.setenv("HOME", str(tmp_path)) + + with ( + patch("pathlib.Path.home", return_value=tmp_path), + patch.object(sys, "argv", ["dj", "setup-claude", "--no-skills", "--no-mcp"]), + patch("sys.stdout", new_callable=StringIO), + ): + from datajunction import cli as dj_cli + + dj_cli.main() + + new_content = agent_file.read_text() + assert "OLD CONTENT" not in new_content + assert "name: dj" in new_content