diff --git a/e2e-tests/mcpchecker/eval.yaml b/e2e-tests/mcpchecker/eval.yaml index 9529d91..7634627 100644 --- a/e2e-tests/mcpchecker/eval.yaml +++ b/e2e-tests/mcpchecker/eval.yaml @@ -69,13 +69,11 @@ config: - path: tasks/cve-cluster-does-exist.yaml assertions: toolsUsed: - - server: stackrox-mcp - toolPattern: "list_clusters" - server: stackrox-mcp toolPattern: "get_clusters_with_orchestrator_cve" argumentsMatch: cveName: "CVE-2016-1000031" - minToolCalls: 2 + minToolCalls: 1 maxToolCalls: 4 # Test 6: CVE with specific cluster filter (does not exist) @@ -85,7 +83,7 @@ config: - server: stackrox-mcp toolPattern: "list_clusters" minToolCalls: 1 - maxToolCalls: 2 + maxToolCalls: 4 # Test 7: CVE detected in clusters - general - path: tasks/cve-clusters-general.yaml diff --git a/internal/toolsets/vulnerability/cluster_resolver.go b/internal/toolsets/vulnerability/cluster_resolver.go new file mode 100644 index 0000000..fdf79e1 --- /dev/null +++ b/internal/toolsets/vulnerability/cluster_resolver.go @@ -0,0 +1,43 @@ +package vulnerability + +import ( + "context" + "fmt" + + v1 "github.com/stackrox/rox/generated/api/v1" + "google.golang.org/grpc" +) + +// resolveClusterID resolves a cluster name to its ID. +// Returns error if cluster name is not found or if API call fails. +func resolveClusterID(ctx context.Context, conn *grpc.ClientConn, + clusterID string, clusterName string) (string, error) { + // Cluster ID has priority. + if clusterID != "" { + return clusterID, nil + } + + if clusterName == "" { + return "", nil + } + + client := v1.NewClustersServiceClient(conn) + + // Use query to filter by cluster name server-side + query := fmt.Sprintf("Cluster:%q", clusterName) + + resp, err := client.GetClusters(ctx, &v1.GetClustersRequest{ + Query: query, + }) + if err != nil { + return "", fmt.Errorf("failed to fetch clusters: %w", err) + } + + clusters := resp.GetClusters() + if len(clusters) == 0 { + return "", fmt.Errorf("cluster with name %q not found", clusterName) + } + + // Return the first matching cluster's ID + return clusters[0].GetId(), nil +} diff --git a/internal/toolsets/vulnerability/cluster_resolver_test.go b/internal/toolsets/vulnerability/cluster_resolver_test.go new file mode 100644 index 0000000..4ace476 --- /dev/null +++ b/internal/toolsets/vulnerability/cluster_resolver_test.go @@ -0,0 +1,143 @@ +package vulnerability + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +func getBufferConnection(t *testing.T, listener *bufconn.Listener) *grpc.ClientConn { + t.Helper() + + // Create a gRPC client connection to the mock server + conn, err := grpc.NewClient( + "passthrough://buffer", + grpc.WithLocalDNSResolution(), + grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + return conn +} + +func TestResolveClusterID_Success(t *testing.T) { + tests := map[string]struct { + clusterID string + clusterName string + mockClusters []*storage.Cluster + mockError error + expectedID string + expectedQuery string + }{ + "only cluster ID": { + clusterID: "only-cluster-id", + clusterName: "", + mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedID: "only-cluster-id", + }, + "cluster ID has priority": { + clusterID: "cluster-with-priority", + clusterName: "production", + mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedID: "cluster-with-priority", + }, + "empty cluster name returns empty ID": { + clusterID: "", + clusterName: "", + mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedID: "", + }, + "cluster name found returns correct ID": { + clusterID: "", + clusterName: "production", + mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedID: "cluster-1", + expectedQuery: `Cluster:"production"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + mockService := mock.NewClustersServiceMock(testCase.mockClusters, testCase.mockError) + + grpcServer, listener := mock.SetupClusterServer(mockService) + defer grpcServer.Stop() + + conn := getBufferConnection(t, listener) + + defer func() { _ = conn.Close() }() + + clusterID, err := resolveClusterID( + context.Background(), + conn, + testCase.clusterID, + testCase.clusterName, + ) + + require.NoError(t, err) + assert.Equal(t, testCase.expectedID, clusterID) + assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery()) + }) + } +} + +func TestResolveClusterID_Failure(t *testing.T) { + tests := map[string]struct { + clusterName string + mockClusters []*storage.Cluster + mockError error + expectedErrText string + expectedQuery string + }{ + "cluster name not found returns error": { + clusterName: "nonexistent", + mockClusters: []*storage.Cluster{}, + expectedErrText: `cluster with name "nonexistent" not found`, + expectedQuery: `Cluster:"nonexistent"`, + }, + "API error propagation": { + clusterName: "production", + mockError: errors.New("API connection failed"), + expectedErrText: "failed to fetch clusters:", + expectedQuery: `Cluster:"production"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + mockService := mock.NewClustersServiceMock(testCase.mockClusters, testCase.mockError) + + grpcServer, listener := mock.SetupClusterServer(mockService) + defer grpcServer.Stop() + + conn := getBufferConnection(t, listener) + + defer func() { _ = conn.Close() }() + + clusterID, err := resolveClusterID( + context.Background(), + conn, + "", + testCase.clusterName, + ) + + require.Error(t, err) + assert.Empty(t, clusterID) + assert.Contains(t, err.Error(), testCase.expectedErrText) + + assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery()) + }) + } +} diff --git a/internal/toolsets/vulnerability/clusters.go b/internal/toolsets/vulnerability/clusters.go index 509e754..0ed16d7 100644 --- a/internal/toolsets/vulnerability/clusters.go +++ b/internal/toolsets/vulnerability/clusters.go @@ -18,8 +18,9 @@ import ( // getClustersForCVEInput defines the input parameters for get_clusters_for_cve tool. type getClustersForCVEInput struct { - CVEName string `json:"cveName"` - FilterClusterID string `json:"filterClusterId,omitempty"` + CVEName string `json:"cveName"` + FilterClusterID string `json:"filterClusterId,omitempty"` + FilterClusterName string `json:"filterClusterName,omitempty"` } func (input *getClustersForCVEInput) validate() error { @@ -27,6 +28,10 @@ func (input *getClustersForCVEInput) validate() error { return errors.New("CVE name is required") } + if input.FilterClusterID != "" && input.FilterClusterName != "" { + return errors.New("cannot specify both filterClusterId and filterClusterName") + } + return nil } @@ -76,9 +81,7 @@ func (t *getClustersForCVETool) GetTool() *mcp.Tool { " Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" + " for comprehensive coverage." + " 2) When user asks specifically about 'orchestrator', 'Kubernetes components'," + - " or 'control plane': Use ONLY this tool." + - " 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," + - " then call ONLY this tool with filterClusterId.", + " or 'control plane': Use ONLY this tool.", InputSchema: getClustersForCVEInputSchema(), } } @@ -97,11 +100,13 @@ func getClustersForCVEInputSchema() *jsonschema.Schema { schema.Properties["cveName"].Description = "CVE name to filter clusters (e.g., CVE-2021-44228)" schema.Properties["filterClusterId"].Description = - "Optional cluster ID (cluster ID only, not cluster name) to verify if CVE is detected in a specific cluster." + - " Only use this parameter when the user's query explicitly mentions a specific cluster name." + - " When checking if a CVE exists at all, call without this parameter to check all clusters at once." + - " To resolve cluster names to IDs, use list_clusters tool first." + - " If the cluster doesn't exist, respond that the CVE is not detected in that cluster (since it doesn't exist)." + "Optional cluster ID to verify if CVE is detected in a specific cluster." + + " Cannot be used together with filterClusterName." + + " When checking if a CVE exists at all, call without this parameter to check all clusters at once." + schema.Properties["filterClusterName"].Description = + "Optional cluster name to verify if CVE is detected in a specific cluster." + + " Cannot be used together with filterClusterId." + + " When checking if a CVE exists at all, call without this parameter to check all clusters at once." return schema } @@ -143,7 +148,18 @@ func (t *getClustersForCVETool) handle( clustersClient := v1.NewClustersServiceClient(conn) - query := buildClusterQuery(input) + // Resolve cluster name to ID if provided + resolvedClusterID, err := resolveClusterID(callCtx, conn, input.FilterClusterID, input.FilterClusterName) + if err != nil { + return nil, nil, err + } + + // Build query using the resolved cluster ID + queryInput := getClustersForCVEInput{ + CVEName: input.CVEName, + FilterClusterID: resolvedClusterID, + } + query := buildClusterQuery(queryInput) resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{ Query: query, diff --git a/internal/toolsets/vulnerability/clusters_test.go b/internal/toolsets/vulnerability/clusters_test.go index 62daf57..a3ebec3 100644 --- a/internal/toolsets/vulnerability/clusters_test.go +++ b/internal/toolsets/vulnerability/clusters_test.go @@ -58,6 +58,8 @@ func TestGetClustersForCVETool_RegisterWith(t *testing.T) { } // Unit tests for input validate method. +// +//nolint:dupl // Duplication to `TestNodeInputValidate` is detected. They use different input types. func TestClusterInputValidate(t *testing.T) { tests := map[string]struct { input getClustersForCVEInput @@ -78,6 +80,29 @@ func TestClusterInputValidate(t *testing.T) { expectError: true, errorMsg: "CVE name is required", }, + "both cluster ID and name provided": { + input: getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + FilterClusterName: "production", + }, + expectError: true, + errorMsg: "cannot specify both filterClusterId and filterClusterName", + }, + "only cluster ID provided": { + input: getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectError: false, + }, + "only cluster name provided": { + input: getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "production", + }, + expectError: false, + }, } for testName, testCase := range tests { @@ -297,3 +322,68 @@ func TestClusterHandle_WithFilters(t *testing.T) { }) } } + +func TestClusterHandle_WithValidClusterNameFilter(t *testing.T) { + tests := map[string]struct { + clusterName string + returnedClusters []*storage.Cluster + expectedQuery string + }{ + "cluster name found": { + clusterName: "production", + returnedClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-1"`, + }, + "empty cluster name": { + clusterName: "", + returnedClusters: []*storage.Cluster{}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + mockService := mock.NewClustersServiceMock(testCase.returnedClusters, nil) + + grpcServer, listener := mock.SetupClusterServer(mockService) + defer grpcServer.Stop() + + tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool) + require.True(t, ok) + + input := getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: testCase.clusterName, + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Contains(t, mockService.GetLastCallQuery(), testCase.expectedQuery) + }) + } +} + +func TestClusterHandle_WithNotValidClusterNameFilter(t *testing.T) { + mockService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil) + + grpcServer, listener := mock.SetupClusterServer(mockService) + defer grpcServer.Stop() + + tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool) + require.True(t, ok) + + input := getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "nonexistent", + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.Error(t, err) + assert.Contains(t, err.Error(), `cluster with name "nonexistent" not found`) + assert.Nil(t, result) + assert.Nil(t, output) +} diff --git a/internal/toolsets/vulnerability/deployments.go b/internal/toolsets/vulnerability/deployments.go index 8637ee9..9df02ed 100644 --- a/internal/toolsets/vulnerability/deployments.go +++ b/internal/toolsets/vulnerability/deployments.go @@ -15,6 +15,7 @@ import ( "github.com/stackrox/stackrox-mcp/internal/cursor" "github.com/stackrox/stackrox-mcp/internal/logging" "github.com/stackrox/stackrox-mcp/internal/toolsets" + "google.golang.org/grpc" ) const ( @@ -33,6 +34,7 @@ const ( type getDeploymentsForCVEInput struct { CVEName string `json:"cveName"` FilterClusterID string `json:"filterClusterId,omitempty"` + FilterClusterName string `json:"filterClusterName,omitempty"` FilterNamespace string `json:"filterNamespace,omitempty"` FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"` IncludeDetectedImages bool `json:"includeDetectedImages,omitempty"` @@ -44,6 +46,10 @@ func (input *getDeploymentsForCVEInput) validate() error { return errors.New("CVE name is required") } + if input.FilterClusterID != "" && input.FilterClusterName != "" { + return errors.New("cannot specify both filterClusterId and filterClusterName") + } + return nil } @@ -100,9 +106,7 @@ func (t *getDeploymentsForCVETool) GetTool() *mcp.Tool { " Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" + " for comprehensive coverage." + " 2) When user asks specifically about 'deployments', 'workloads', 'applications'," + - " or 'containers': Use ONLY this tool." + - " 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," + - " then call ONLY this tool with filterClusterId.", + " or 'containers': Use ONLY this tool.", InputSchema: getDeploymentsForCVEInputSchema(), } } @@ -120,7 +124,10 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema { schema.Required = []string{"cveName"} schema.Properties["cveName"].Description = "CVE name to filter deployments (e.g., CVE-2021-44228)" - schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter deployments" + schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter deployments." + + " Cannot be used together with filterClusterName." + schema.Properties["filterClusterName"].Description = "Optional cluster name to filter deployments." + + " Cannot be used together with filterClusterId." schema.Properties["filterNamespace"].Description = "Optional namespace to filter deployments" schema.Properties["filterPlatform"].Description = @@ -263,6 +270,18 @@ func fetchImagesForDeployment( return images, nil } +func enrichDeploymentsWithImages(ctx context.Context, conn *grpc.ClientConn, + cveName string, deployments []DeploymentResult) { + imageClient := v1.NewImageServiceClient(conn) + enricher := newDeploymentEnricher(imageClient, cveName, defaultMaxFetchImageConcurrency) + + for i := range deployments { + enricher.enrich(ctx, &deployments[i]) + } + + enricher.wait() +} + // handle is the handler for get_deployments_for_cve tool. // //nolint:funlen @@ -287,10 +306,25 @@ func (t *getDeploymentsForCVETool) handle( } callCtx := auth.WithMCPRequestContext(ctx, req) + + // Resolve cluster name to ID if provided + resolvedClusterID, err := resolveClusterID(callCtx, conn, input.FilterClusterID, input.FilterClusterName) + if err != nil { + return nil, nil, err + } + deploymentClient := v1.NewDeploymentServiceClient(conn) + // Build query using the resolved cluster ID + queryInput := getDeploymentsForCVEInput{ + CVEName: input.CVEName, + FilterClusterID: resolvedClusterID, + FilterNamespace: input.FilterNamespace, + FilterPlatform: input.FilterPlatform, + } + listReq := &v1.RawQuery{ - Query: buildQuery(input), + Query: buildQuery(queryInput), Pagination: &v1.Pagination{ Offset: currCursor.GetOffset(), Limit: defaultLimit + 1, @@ -316,14 +350,7 @@ func (t *getDeploymentsForCVETool) handle( } if input.IncludeDetectedImages { - imageClient := v1.NewImageServiceClient(conn) - enricher := newDeploymentEnricher(imageClient, input.CVEName, defaultMaxFetchImageConcurrency) - - for i := range deployments { - enricher.enrich(callCtx, &deployments[i]) - } - - enricher.wait() + enrichDeploymentsWithImages(callCtx, conn, input.CVEName, deployments) } // We always fetch limit+1 - if we do not have one additional element we can end paging. diff --git a/internal/toolsets/vulnerability/deployments_test.go b/internal/toolsets/vulnerability/deployments_test.go index 01bc035..e8bd453 100644 --- a/internal/toolsets/vulnerability/deployments_test.go +++ b/internal/toolsets/vulnerability/deployments_test.go @@ -87,6 +87,29 @@ func TestInputValidate(t *testing.T) { expectError: true, errorMsg: "CVE name is required", }, + "both cluster ID and name provided": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + FilterClusterName: "production", + }, + expectError: true, + errorMsg: "cannot specify both filterClusterId and filterClusterName", + }, + "only cluster ID provided": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectError: false, + }, + "only cluster name provided": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "production", + }, + expectError: false, + }, } for testName, testCase := range tests { @@ -513,3 +536,80 @@ func TestHandle_ImageFetchPartialFailure(t *testing.T) { } } } + +func TestDeploymentHandle_WithValidClusterNameFilter(t *testing.T) { + tests := map[string]struct { + clusterName string + returnedClusters []*storage.Cluster + expectedQuery string + }{ + "cluster name found": { + clusterName: "production", + returnedClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-1"`, + }, + "empty cluster name": { + clusterName: "", + returnedClusters: []*storage.Cluster{}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + deploymentService := mock.NewDeploymentServiceMock(getTestDeployments(1), nil) + clusterService := mock.NewClustersServiceMock(testCase.returnedClusters, nil) + + grpcServer, listener := mock.SetupAPIServer( + deploymentService, + v1.UnimplementedImageServiceServer{}, + v1.UnimplementedNodeServiceServer{}, + clusterService, + ) + defer grpcServer.Stop() + + tool, ok := NewGetDeploymentsForCVETool(createTestClient(t, listener)).(*getDeploymentsForCVETool) + require.True(t, ok) + + input := getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: testCase.clusterName, + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Contains(t, deploymentService.GetLastCallQuery(), testCase.expectedQuery) + }) + } +} + +func TestDeploymentHandle_WithNotValidClusterNameFilter(t *testing.T) { + deploymentService := mock.NewDeploymentServiceMock(getTestDeployments(1), nil) + clusterService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil) + + grpcServer, listener := mock.SetupAPIServer( + deploymentService, + v1.UnimplementedImageServiceServer{}, + v1.UnimplementedNodeServiceServer{}, + clusterService, + ) + defer grpcServer.Stop() + + tool, ok := NewGetDeploymentsForCVETool(createTestClient(t, listener)).(*getDeploymentsForCVETool) + require.True(t, ok) + + input := getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "nonexistent", + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.Error(t, err) + assert.Contains(t, err.Error(), `cluster with name "nonexistent" not found`) + assert.Nil(t, result) + assert.Nil(t, output) +} diff --git a/internal/toolsets/vulnerability/nodes.go b/internal/toolsets/vulnerability/nodes.go index 357f8a8..a00c58f 100644 --- a/internal/toolsets/vulnerability/nodes.go +++ b/internal/toolsets/vulnerability/nodes.go @@ -20,8 +20,9 @@ import ( // getNodesForCVEInput defines the input parameters for get_nodes_for_cve tool. type getNodesForCVEInput struct { - CVEName string `json:"cveName"` - FilterClusterID string `json:"filterClusterId,omitempty"` + CVEName string `json:"cveName"` + FilterClusterID string `json:"filterClusterId,omitempty"` + FilterClusterName string `json:"filterClusterName,omitempty"` } func (input *getNodesForCVEInput) validate() error { @@ -29,6 +30,10 @@ func (input *getNodesForCVEInput) validate() error { return errors.New("CVE name is required") } + if input.FilterClusterID != "" && input.FilterClusterName != "" { + return errors.New("cannot specify both filterClusterId and filterClusterName") + } + return nil } @@ -80,9 +85,7 @@ func (t *getNodesForCVETool) GetTool() *mcp.Tool { " Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" + " for comprehensive coverage." + " 2) When user asks specifically about 'nodes', 'hosts'," + - " or 'operating systems': Use ONLY this tool." + - " 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," + - " then call ONLY this tool with filterClusterId.", + " or 'operating systems': Use ONLY this tool.", InputSchema: getNodesForCVEInputSchema(), } } @@ -100,7 +103,10 @@ func getNodesForCVEInputSchema() *jsonschema.Schema { schema.Required = []string{"cveName"} schema.Properties["cveName"].Description = "CVE name to filter nodes (e.g., CVE-2020-26159)" - schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter nodes" + schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter nodes." + + " Cannot be used together with filterClusterName." + schema.Properties["filterClusterName"].Description = "Optional cluster name to filter nodes." + + " Cannot be used together with filterClusterId." return schema } @@ -197,9 +203,21 @@ func (t *getNodesForCVETool) handle( } callCtx := auth.WithMCPRequestContext(ctx, req) + + // Resolve cluster name to ID if provided + resolvedClusterID, err := resolveClusterID(callCtx, conn, input.FilterClusterID, input.FilterClusterName) + if err != nil { + return nil, nil, err + } + nodeClient := v1.NewNodeServiceClient(conn) - query := buildNodeQuery(input) + // Build query using the resolved cluster ID + queryInput := getNodesForCVEInput{ + CVEName: input.CVEName, + FilterClusterID: resolvedClusterID, + } + query := buildNodeQuery(queryInput) exportReq := &v1.ExportNodeRequest{ Query: query, } diff --git a/internal/toolsets/vulnerability/nodes_test.go b/internal/toolsets/vulnerability/nodes_test.go index 49cd8b6..b59a4a5 100644 --- a/internal/toolsets/vulnerability/nodes_test.go +++ b/internal/toolsets/vulnerability/nodes_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" @@ -57,6 +58,8 @@ func TestGetNodesForCVETool_RegisterWith(t *testing.T) { } // Unit tests for input validate method. +// +//nolint:dupl // Duplication to `TestClusterInputValidate` is detected. They use different input types. func TestNodeInputValidate(t *testing.T) { tests := map[string]struct { input getNodesForCVEInput @@ -77,6 +80,29 @@ func TestNodeInputValidate(t *testing.T) { expectError: true, errorMsg: "CVE name is required", }, + "both cluster ID and name provided": { + input: getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + FilterClusterName: "production", + }, + expectError: true, + errorMsg: "cannot specify both filterClusterId and filterClusterName", + }, + "only cluster ID provided": { + input: getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectError: false, + }, + "only cluster name provided": { + input: getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "production", + }, + expectError: false, + }, } for testName, testCase := range tests { @@ -326,3 +352,90 @@ func TestNodeHandle_WithFilters(t *testing.T) { }) } } + +func TestNodeHandle_WithValidClusterNameFilter(t *testing.T) { + tests := map[string]struct { + clusterName string + returnedClusters []*storage.Cluster + expectedQuery string + }{ + "cluster name found": { + clusterName: "production", + returnedClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}}, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-1"`, + }, + "empty cluster name": { + clusterName: "", + returnedClusters: []*storage.Cluster{}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + nodeService := mock.NewNodeServiceMock( + []*storage.Node{ + {Name: "n1", ClusterId: "cluster-1", ClusterName: "production", OsImage: "Ubuntu 20.04"}, + }, + nil, + ) + clusterService := mock.NewClustersServiceMock(testCase.returnedClusters, nil) + + grpcServer, listener := mock.SetupAPIServer( + v1.UnimplementedDeploymentServiceServer{}, + v1.UnimplementedImageServiceServer{}, + nodeService, + clusterService, + ) + defer grpcServer.Stop() + + tool, ok := NewGetNodesForCVETool(createTestClient(t, listener)).(*getNodesForCVETool) + require.True(t, ok) + + input := getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: testCase.clusterName, + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Contains(t, nodeService.GetLastCallQuery(), testCase.expectedQuery) + }) + } +} + +func TestNodeHandle_WithNotValidClusterNameFilter(t *testing.T) { + nodeService := mock.NewNodeServiceMock( + []*storage.Node{ + {Name: "n1", ClusterId: "cluster-1", ClusterName: "production", OsImage: "Ubuntu 20.04"}, + }, + nil, + ) + clusterService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil) + + grpcServer, listener := mock.SetupAPIServer( + v1.UnimplementedDeploymentServiceServer{}, + v1.UnimplementedImageServiceServer{}, + nodeService, + clusterService, + ) + defer grpcServer.Stop() + + tool, ok := NewGetNodesForCVETool(createTestClient(t, listener)).(*getNodesForCVETool) + require.True(t, ok) + + input := getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterName: "nonexistent", + } + + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input) + + require.Error(t, err) + assert.Contains(t, err.Error(), `cluster with name "nonexistent" not found`) + assert.Nil(t, result) + assert.Nil(t, output) +}