diff --git a/.golangci.yml b/.golangci.yml index 3e7ff0d..e4d7468 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,6 +19,9 @@ linters: - paralleltest - testpackage - noinlineerr + - ireturn + # requires package-level variables created from errors.New() + - err113 issues: max-issues-per-linter: 0 max-same-issues: 0 diff --git a/Makefile b/Makefile index de29de0..ed8df6c 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,9 @@ # Binary name BINARY_NAME=stackrox-mcp +# Version (can be overridden with VERSION=x.y.z make build) +VERSION?=0.1.0 + # Go parameters GOCMD=go GOBUILD=$(GOCMD) build @@ -11,6 +14,9 @@ GOTEST=$(GOCMD) test GOFMT=$(GOCMD) fmt GOCLEAN=$(GOCMD) clean +# Build flags +LDFLAGS=-ldflags "-X github.com/stackrox/stackrox-mcp/internal/server.version=$(VERSION)" + # Coverage files COVERAGE_OUT=coverage.out @@ -24,7 +30,7 @@ help: ## Display this help message .PHONY: build build: ## Build the binary - $(GOBUILD) -o $(BINARY_NAME) ./cmd/stackrox-mcp + $(GOBUILD) $(LDFLAGS) -o $(BINARY_NAME) ./cmd/stackrox-mcp .PHONY: test test: ## Run unit tests with coverage diff --git a/README.md b/README.md index 8b72e86..331e507 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ export STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED=true ./stackrox-mcp ``` +The server will start on `http://localhost:8080` by default. See the [Testing the MCP Server](#testing-the-mcp-server) section for instructions on connecting with Claude Code. + ## Configuration The StackRox MCP server supports configuration through both YAML files and environment variables. Environment variables take precedence over YAML configuration. @@ -80,6 +82,15 @@ Global MCP server settings. |--------|---------------------|------|----------|---------|-------------| | `global.read_only_tools` | `STACKROX_MCP__GLOBAL__READ_ONLY_TOOLS` | bool | No | `true` | Only allow read-only tools | +#### Server Configuration + +HTTP server settings for the MCP server. + +| Option | Environment Variable | Type | Required | Default | Description | +|--------|---------------------|------|----------|---------|-------------| +| `server.address` | `STACKROX_MCP__SERVER__ADDRESS` | string | No | `localhost` | HTTP server listen address | +| `server.port` | `STACKROX_MCP__SERVER__PORT` | int | No | `8080` | HTTP server listen port (must be 1-65535) | + #### Tools Configuration Enable or disable individual MCP tools. At least one tool has to be enabled. @@ -97,6 +108,63 @@ Configuration values are loaded in the following order (later sources override e 2. YAML configuration file (if provided via `--config`) 3. Environment variables (highest precedence) +## Testing the MCP Server + +### Starting the Server + +Start the server with a configuration file: + +```bash +./stackrox-mcp --config examples/config-read-only.yaml +``` + +Or using environment variables: + +```bash +export STACKROX_MCP__CENTRAL__URL="central.example.com:8443" +export STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED="true" +./stackrox-mcp +``` + +The server will start on `http://localhost:8080` by default (configurable via `server.address` and `server.port`). + +### Connecting with Claude Code CLI + +Add the MCP server to Claude Code using command-line options: + +```bash +claude mcp add stackrox \ + --name "StackRox MCP Server" \ + --transport http \ + --url http://localhost:8080 +``` + +### Verifying Connection + +List configured MCP servers: + +```bash +claude mcp list +``` + +Get details for a specific server: + +```bash +claude mcp get stackrox +``` + +Within a Claude Code session, use the `/mcp` command to view available tools from connected servers. + +### Example Usage + +Once connected, interact with the tools using natural language: + +**List all clusters:** +``` +You: "Can you list all the clusters from StackRox?" +Claude: [Uses list_clusters tool to retrieve cluster information] +``` + ## Development For detailed development guidelines, testing standards, and contribution workflows, see [CONTRIBUTING.md](.github/CONTRIBUTING.md). diff --git a/cmd/stackrox-mcp/main.go b/cmd/stackrox-mcp/main.go index 2c81c41..73f023e 100644 --- a/cmd/stackrox-mcp/main.go +++ b/cmd/stackrox-mcp/main.go @@ -2,14 +2,29 @@ package main import ( + "context" "flag" "log/slog" "os" + "os/signal" + "syscall" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/logging" + "github.com/stackrox/stackrox-mcp/internal/server" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + toolsetConfig "github.com/stackrox/stackrox-mcp/internal/toolsets/config" + toolsetVulnerability "github.com/stackrox/stackrox-mcp/internal/toolsets/vulnerability" ) +// getToolsets initializes and returns all available toolsets. +func getToolsets(cfg *config.Config) []toolsets.Toolset { + return []toolsets.Toolset{ + toolsetConfig.NewToolset(cfg), + toolsetVulnerability.NewToolset(cfg), + } +} + func main() { logging.SetupLogging() @@ -19,11 +34,30 @@ func main() { cfg, err := config.LoadConfig(*configPath) if err != nil { - slog.Error("Failed to load configuration", "error", err) - os.Exit(1) + logging.Fatal("Failed to load configuration", err) } slog.Info("Configuration loaded successfully", "config", cfg) + registry := toolsets.NewRegistry(cfg, getToolsets(cfg)) + srv := server.NewServer(cfg, registry) + + // Set up context with signal handling for graceful shutdown. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + go func() { + <-sigChan + slog.Info("Received shutdown signal") + cancel() + }() + slog.Info("Starting Stackrox MCP server") + + if err := srv.Start(ctx); err != nil { + logging.Fatal("Server error", err) + } } diff --git a/cmd/stackrox-mcp/main_test.go b/cmd/stackrox-mcp/main_test.go new file mode 100644 index 0000000..38ddda8 --- /dev/null +++ b/cmd/stackrox-mcp/main_test.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "errors" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/server" + "github.com/stackrox/stackrox-mcp/internal/testutil" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getDefaultConfig() *config.Config { + return &config.Config{ + Global: config.GlobalConfig{ + ReadOnlyTools: false, + }, + Central: config.CentralConfig{ + URL: "central.example.com:8443", + }, + Server: config.ServerConfig{ + Address: "localhost", + Port: 8080, + }, + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: true, + }, + ConfigManager: config.ToolConfigManagerConfig{ + Enabled: false, + }, + }, + } +} + +func TestGetToolsets(t *testing.T) { + cfg := getDefaultConfig() + cfg.Tools.ConfigManager.Enabled = true + + allToolsets := getToolsets(cfg) + + require.NotNil(t, allToolsets) + assert.Len(t, allToolsets, 2, "Should have 2 allToolsets") + assert.Equal(t, "config_manager", allToolsets[0].GetName()) + assert.Equal(t, "vulnerability", allToolsets[1].GetName()) +} + +func TestGracefulShutdown(t *testing.T) { + // Set up minimal valid config. + t.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true") + + cfg, err := config.LoadConfig("") + require.NoError(t, err) + require.NotNil(t, cfg) + cfg.Server.Port = testutil.GetPortForTest(t) + + registry := toolsets.NewRegistry(cfg, getToolsets(cfg)) + srv := server.NewServer(cfg, registry) + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error, 1) + + go func() { + errChan <- srv.Start(ctx) + }() + + serverURL := "http://" + net.JoinHostPort(cfg.Server.Address, strconv.Itoa(cfg.Server.Port)) + err = testutil.WaitForServerReady(serverURL, 3*time.Second) + require.NoError(t, err, "Server should start within timeout") + + // Establish actual HTTP connection to verify server is responding. + //nolint:gosec,noctx + resp, err := http.Get(serverURL) + if err == nil { + _ = resp.Body.Close() + } + + require.NoError(t, err, "Should be able to establish HTTP connection to server") + + // Simulate shutdown signal by canceling context. + cancel() + + // Wait for server to shut down. + select { + case err := <-errChan: + // Server should shut down cleanly (either nil or context.Canceled). + if err != nil && errors.Is(err, context.Canceled) { + t.Errorf("Server returned unexpected error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Server did not shut down within timeout period") + } +} diff --git a/examples/config-read-only.yaml b/examples/config-read-only.yaml index 2c81cd9..24d45d9 100644 --- a/examples/config-read-only.yaml +++ b/examples/config-read-only.yaml @@ -36,6 +36,17 @@ global: # When false, both read and write tools may be available (if implemented) read_only_tools: true +# HTTP server configuration +server: + # Server listen address (optional, default: localhost) + # The address on which the MCP HTTP server will listen + address: localhost + + # Server listen port (optional, default: 8080) + # The port on which the MCP HTTP server will listen + # Must be between 1 and 65535 + port: 8080 + # Configuration of MCP tools # Each tool has an enable/disable flag. At least one tool has to be enabled. tools: diff --git a/go.mod b/go.mod index 0aa36fb..86e7876 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/stackrox/stackrox-mcp go 1.24 require ( + github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/pkg/errors v0.9.1 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 @@ -12,6 +13,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect @@ -20,7 +22,9 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index de4a809..c2ac974 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,16 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= +github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -36,12 +40,18 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config/config.go b/internal/config/config.go index 175d63d..6796d43 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,10 +8,13 @@ import ( "github.com/spf13/viper" ) +const defaultPort = 8080 + // Config represents the complete application configuration. type Config struct { Central CentralConfig `mapstructure:"central"` Global GlobalConfig `mapstructure:"global"` + Server ServerConfig `mapstructure:"server"` Tools ToolsConfig `mapstructure:"tools"` } @@ -27,6 +30,12 @@ type GlobalConfig struct { ReadOnlyTools bool `mapstructure:"read_only_tools"` } +// ServerConfig contains HTTP server configuration. +type ServerConfig struct { + Address string `mapstructure:"address"` + Port int `mapstructure:"port"` +} + // ToolsConfig contains configuration for individual MCP tools. type ToolsConfig struct { Vulnerability ToolsetVulnerabilityConfig `mapstructure:"vulnerability"` @@ -87,6 +96,9 @@ func setDefaults(viper *viper.Viper) { viper.SetDefault("global.read_only_tools", true) + viper.SetDefault("server.address", "localhost") + viper.SetDefault("server.port", defaultPort) + viper.SetDefault("tools.vulnerability.enabled", false) viper.SetDefault("tools.config_manager.enabled", false) } @@ -102,6 +114,14 @@ func (c *Config) Validate() error { return errURLRequired } + if c.Server.Address == "" { + return errors.New("server.address is required") + } + + if c.Server.Port < 1 || c.Server.Port > 65535 { + return errors.New("server.port must be between 1 and 65535") + } + if !c.Tools.Vulnerability.Enabled && !c.Tools.ConfigManager.Enabled { return errAtLeastOneTool } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6b07419..b59bce2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,17 +2,32 @@ package config import ( "os" - "path/filepath" "testing" + "github.com/stackrox/stackrox-mcp/internal/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestLoadConfig_FromYAML(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") +// getDefaultConfig returns a default config for testing validation logic. +func getDefaultConfig() *Config { + return &Config{ + Central: CentralConfig{ + URL: "central.example.com:8443", + }, + Server: ServerConfig{ + Address: "localhost", + Port: 8080, + }, + Tools: ToolsConfig{ + Vulnerability: ToolsetVulnerabilityConfig{ + Enabled: true, + }, + }, + } +} +func TestLoadConfig_FromYAML(t *testing.T) { yamlContent := ` central: url: central.example.com:8443 @@ -26,8 +41,8 @@ tools: config_manager: enabled: False ` - err := os.WriteFile(configPath, []byte(yamlContent), 0600) - require.NoError(t, err) + + configPath := testutil.WriteYAMLFile(t, yamlContent) defer func() { assert.NoError(t, os.Remove(configPath)) }() @@ -52,10 +67,7 @@ tools: enabled: false ` - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - err := os.WriteFile(configPath, []byte(yamlContent), 0600) - require.NoError(t, err) + configPath := testutil.WriteYAMLFile(t, yamlContent) defer func() { assert.NoError(t, os.Remove(configPath)) }() @@ -91,7 +103,7 @@ func TestLoadConfig_EnvVarOnly(t *testing.T) { } func TestLoadConfig_Defaults(t *testing.T) { - // Set only required field + // Set only required field. t.Setenv("STACKROX_MCP__TOOLS__CONFIG_MANAGER__ENABLED", "true") cfg, err := LoadConfig("") @@ -124,33 +136,55 @@ func TestLoadConfig_MissingFile(t *testing.T) { } func TestLoadConfig_InvalidYAML(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - invalidYAML := ` central: url: central.example.com:8443 invalid yaml syntax here: [[[ ` - err := os.WriteFile(configPath, []byte(invalidYAML), 0600) - require.NoError(t, err) + configPath := testutil.WriteYAMLFile(t, invalidYAML) + + defer func() { assert.NoError(t, os.Remove(configPath)) }() - _, err = LoadConfig(configPath) + _, err := LoadConfig(configPath) assert.Error(t, err) } +func TestLoadConfig_UnmarshalFailure(t *testing.T) { + // YAML with type mismatch - port should be int. + invalidTypeYAML := ` +server: + port: "not-a-number" +` + configPath := testutil.WriteYAMLFile(t, invalidTypeYAML) + _, err := LoadConfig(configPath) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal config") +} + +func TestLoadConfig_ValidationFailure(t *testing.T) { + // Valid YAML but fails on central URL validation (no URL). + validYAMLInvalidConfig := ` +central: + url: "" +server: + address: localhost + port: 8080 +tools: + vulnerability: + enabled: true +` + + configPath := testutil.WriteYAMLFile(t, validYAMLInvalidConfig) + _, err := LoadConfig(configPath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid configuration") + assert.Contains(t, err.Error(), "central.url is required") +} + func TestValidate_MissingURL(t *testing.T) { - cfg := &Config{ - Central: CentralConfig{ - URL: "", - }, - Tools: ToolsConfig{ - Vulnerability: ToolsetVulnerabilityConfig{ - Enabled: true, - }, - }, - } + cfg := getDefaultConfig() + cfg.Central.URL = "" err := cfg.Validate() require.Error(t, err) @@ -158,11 +192,8 @@ func TestValidate_MissingURL(t *testing.T) { } func TestValidate_AtLeastOneTool(t *testing.T) { - cfg := &Config{ - Central: CentralConfig{ - URL: "central.example.com:8443", - }, - } + cfg := getDefaultConfig() + cfg.Tools.Vulnerability.Enabled = false err := cfg.Validate() require.Error(t, err) @@ -170,25 +201,50 @@ func TestValidate_AtLeastOneTool(t *testing.T) { } func TestValidate_ValidConfig(t *testing.T) { - cfg := &Config{ - Central: CentralConfig{ - URL: "central.example.com:8443", - Insecure: false, - ForceHTTP1: false, - }, - Global: GlobalConfig{ - ReadOnlyTools: true, - }, - Tools: ToolsConfig{ - Vulnerability: ToolsetVulnerabilityConfig{ - Enabled: true, - }, - ConfigManager: ToolConfigManagerConfig{ - Enabled: false, - }, - }, - } + cfg := getDefaultConfig() err := cfg.Validate() assert.NoError(t, err) } + +func TestValidate_MissingServerAddress(t *testing.T) { + cfg := getDefaultConfig() + cfg.Server.Address = "" + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "server.address is required") +} + +func TestValidate_InvalidServerPort(t *testing.T) { + tests := map[string]struct { + port int + }{ + "zero port": {port: 0}, + "negative port": {port: -1}, + "port too high": {port: 65536}, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + cfg := getDefaultConfig() + cfg.Server.Port = tt.port + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "server.port must be between 1 and 65535") + }) + } +} + +func TestLoadConfig_ServerDefaults(t *testing.T) { + // Set only required fields. + t.Setenv("STACKROX_MCP__TOOLS__CONFIG_MANAGER__ENABLED", "true") + + cfg, err := LoadConfig("") + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, "localhost", cfg.Server.Address) + assert.Equal(t, 8080, cfg.Server.Port) +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 45a3182..0ecb7b4 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -30,10 +30,16 @@ func SetupLogging() { } } - // Initialize slog with JSON handler + // Initialize slog with JSON handler. logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: logLevel, })) slog.SetDefault(logger) } + +// Fatal logs and error and exits with exit code 1. +func Fatal(msg string, err error) { + slog.Error(msg, "error", err) + os.Exit(1) +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index 6f0ef1d..d32e907 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -11,12 +11,12 @@ import ( ) func TestSetupLogging(t *testing.T) { - // Clear any existing LOG_LEVEL environment variable + // Clear any existing LOG_LEVEL environment variable. require.NoError(t, os.Unsetenv("LOG_LEVEL")) SetupLogging() - // Verify default log level is INFO + // Verify default log level is INFO. handler := slog.Default().Handler() assert.True(t, handler.Enabled(context.Background(), slog.LevelInfo)) assert.False(t, handler.Enabled(context.Background(), slog.LevelDebug)) diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..3d581b3 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,118 @@ +// Package server represents MCP server. +package server + +import ( + "context" + "log/slog" + "net" + "net/http" + "strconv" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +const ( + shutdownTimeout = 5 * time.Second + readHeaderTimeout = 5 * time.Second +) + +// version is set at build time via ldflags (ldflags can't modify constants). +var version = "dev" + +// Server represents the MCP HTTP server. +type Server struct { + cfg *config.Config + registry *toolsets.Registry + mcp *mcp.Server +} + +// NewServer creates a new MCP server instance. +func NewServer(cfg *config.Config, registry *toolsets.Registry) *Server { + mcpServer := mcp.NewServer( + &mcp.Implementation{ + Name: "stackrox-mcp", + Version: version, + }, + nil, + ) + + return &Server{ + cfg: cfg, + registry: registry, + mcp: mcpServer, + } +} + +// Start starts the HTTP server with Streamable HTTP transport. +func (s *Server) Start(ctx context.Context) error { + s.registerTools() + + handler := mcp.NewStreamableHTTPHandler( + func(*http.Request) *mcp.Server { + return s.mcp + }, + nil, + ) + + addr := net.JoinHostPort(s.cfg.Server.Address, strconv.Itoa(s.cfg.Server.Port)) + httpServer := &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: readHeaderTimeout, + } + + slog.Info("Starting MCP HTTP server", "address", s.cfg.Server.Address, "port", s.cfg.Server.Port) + + // Start server in a goroutine. + errChan := make(chan error, 1) + + go func() { + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errChan <- errors.Wrap(err, "HTTP server error") + } + }() + + // Wait for context cancellation or server error. + select { + case <-ctx.Done(): + slog.Info("Shutting down HTTP server") + // Create a context with timeout for graceful shutdown. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer shutdownCancel() + //nolint:contextcheck + return errors.Wrap(httpServer.Shutdown(shutdownCtx), "server shutting down failed") + case err := <-errChan: + return err + } +} + +// registerTools registers all tools from the registry with the MCP server. +func (s *Server) registerTools() { + slog.Info("Registering MCP tools") + + for _, toolset := range s.registry.GetToolsets() { + if !toolset.IsEnabled() { + slog.Info("Skipping disabled toolset", "toolset", toolset.GetName()) + + continue + } + + for _, tool := range toolset.GetTools() { + if s.cfg.Global.ReadOnlyTools && !tool.IsReadOnly() { + slog.Info("Skipping read-write tool (read-only mode enabled)", "tool", tool.GetName()) + + continue + } + + slog.Info("Registering tool", "toolset", toolset.GetName(), "tool", tool.GetName(), "read_only", tool.IsReadOnly()) + + tool.RegisterWith(s.mcp) + } + } + + slog.Info("Tools registration complete") +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..199f004 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,168 @@ +package server + +import ( + "context" + "errors" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/testutil" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// getDefaultConfig returns a default config for tests. +func getDefaultConfig() *config.Config { + return &config.Config{ + Global: config.GlobalConfig{ + ReadOnlyTools: false, + }, + Central: config.CentralConfig{ + URL: "central.example.com:8443", + }, + Server: config.ServerConfig{ + Address: "localhost", + Port: 8080, + }, + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: true, + }, + ConfigManager: config.ToolConfigManagerConfig{ + Enabled: false, + }, + }, + } +} + +func TestNewServer(t *testing.T) { + cfg := getDefaultConfig() + + registry := toolsets.NewRegistry(cfg, []toolsets.Toolset{}) + + srv := NewServer(cfg, registry) + + require.NotNil(t, srv) + assert.Equal(t, cfg, srv.cfg) + assert.Equal(t, registry, srv.registry) + assert.NotNil(t, srv.mcp) +} + +func TestServer_registerTools_AllEnabled(t *testing.T) { + cfg := getDefaultConfig() + cfg.Global.ReadOnlyTools = false + + readOnlyTestTool := mock.NewTool("test_read_only_tool", true) + readWriteTestTool := mock.NewTool("test_read_write_tool", false) + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("test_toolset", true, []toolsets.Tool{readOnlyTestTool, readWriteTestTool}), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + srv := NewServer(cfg, registry) + + srv.registerTools() + + assert.True(t, readOnlyTestTool.RegisterCalled, "read-only test tool should be registered") + assert.True(t, readWriteTestTool.RegisterCalled, "read-write test tool should be registered") +} + +func TestServer_registerTools_ReadOnlyMode(t *testing.T) { + cfg := getDefaultConfig() + cfg.Global.ReadOnlyTools = true + + readOnlyTestTool := mock.NewTool("test_read_only_tool", true) + readWriteTestTool := mock.NewTool("test_read_write_tool", false) + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("test_toolset", true, []toolsets.Tool{readOnlyTestTool, readWriteTestTool}), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + srv := NewServer(cfg, registry) + + srv.registerTools() + + assert.True(t, readOnlyTestTool.RegisterCalled, "read-only test tool should be registered") + assert.False(t, readWriteTestTool.RegisterCalled, "read-write test tool should not be registered in read-only mode") +} + +func TestServer_registerTools_DisabledToolset(t *testing.T) { + cfg := getDefaultConfig() + cfg.Global.ReadOnlyTools = false + + enabledTestTool := mock.NewTool("test_enabled_tool", true) + disabledTestTool := mock.NewTool("test_disabled_tool", true) + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("enabled_toolset", true, []toolsets.Tool{enabledTestTool}), + mock.NewToolset("disabled_toolset", false, []toolsets.Tool{disabledTestTool}), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + srv := NewServer(cfg, registry) + + srv.registerTools() + + assert.True(t, enabledTestTool.RegisterCalled, "tool from enabled toolset should be registered") + assert.False(t, disabledTestTool.RegisterCalled, "tool from disabled toolset should not be registered") +} + +func TestServer_Start(t *testing.T) { + cfg := getDefaultConfig() + cfg.Server.Port = testutil.GetPortForTest(t) + + testTool := mock.NewTool("test_tool", true) + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("test_toolset", true, []toolsets.Tool{testTool}), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + srv := NewServer(cfg, registry) + + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error, 1) + + go func() { + errChan <- srv.Start(ctx) + }() + + serverURL := "http://" + net.JoinHostPort(cfg.Server.Address, strconv.Itoa(cfg.Server.Port)) + err := testutil.WaitForServerReady(serverURL, 3*time.Second) + require.NoError(t, err, "Server should start within timeout") + + // Verify tools were registered. + assert.True(t, testTool.RegisterCalled, "test tool should be registered when server starts") + + // Establish actual HTTP connection to verify server is responding. + //nolint:gosec,noctx + resp, err := http.Get(serverURL) + if err == nil { + _ = resp.Body.Close() + } + // We don't require a successful response, just that we can connect + require.NoError(t, err, "Should be able to establish HTTP connection to server") + + // Trigger graceful shutdown. + cancel() + + // Wait for server to shut down. + select { + case err := <-errChan: + // Server should shut down cleanly. + if err != nil && errors.Is(err, context.Canceled) { + t.Errorf("Server returned unexpected error: %v", err) + } + case <-time.After(shutdownTimeout): + t.Fatal("Server did not shut down within timeout period") + } +} diff --git a/internal/testutil/config.go b/internal/testutil/config.go new file mode 100644 index 0000000..bddb289 --- /dev/null +++ b/internal/testutil/config.go @@ -0,0 +1,30 @@ +// Package testutil contains test helpers. +package testutil + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +const defaultFilePermissions = 0o600 + +// WriteYAMLFile writes the given YAML content to a temporary file and returns its path. +// The file will be automatically cleaned up when the test completes. +// Returns the absolute path to the created file. +func WriteYAMLFile(t *testing.T, content string) string { + t.Helper() + + tmpDir := t.TempDir() + // Use test name to create unique filename for parallel test execution + filename := fmt.Sprintf("config-%s.yaml", t.Name()) + configPath := filepath.Join(tmpDir, filename) + + err := os.WriteFile(configPath, []byte(content), defaultFilePermissions) + if err != nil { + t.Fatalf("Failed to write YAML file: %v", err) + } + + return configPath +} diff --git a/internal/testutil/ports.go b/internal/testutil/ports.go new file mode 100644 index 0000000..60f21a6 --- /dev/null +++ b/internal/testutil/ports.go @@ -0,0 +1,29 @@ +package testutil + +import ( + "hash/fnv" + "testing" +) + +const ( + minPort = 10000 + maxPort = 60000 +) + +// GetPortForTest returns a deterministic port number based on the test name. +// This ensures that each test gets a unique, reproducible port number for parallel execution. +// The port is calculated by hashing the test name and mapping it to a safe range. +func GetPortForTest(t *testing.T) int { + t.Helper() + + // Hash the test name using FNV-1a + h := fnv.New32a() + _, _ = h.Write([]byte(t.Name())) + hash := h.Sum32() + + // Map the hash to the port range [minPort, maxPort) + portRange := maxPort - minPort + port := minPort + int(hash%uint32(portRange)) + + return port +} diff --git a/internal/testutil/server.go b/internal/testutil/server.go new file mode 100644 index 0000000..8bcb257 --- /dev/null +++ b/internal/testutil/server.go @@ -0,0 +1,31 @@ +package testutil + +import ( + "fmt" + "net/http" + "time" +) + +const timeoutDuration = 100 * time.Millisecond + +// WaitForServerReady polls the server until it's ready to accept connections. +// This function is useful for integration tests where you need to wait for +// a server to start before making requests to it. +func WaitForServerReady(address string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + client := &http.Client{Timeout: timeoutDuration} + + for time.Now().Before(deadline) { + //nolint:noctx + resp, err := client.Get(address) + if err == nil { + _ = resp.Body.Close() + + return nil + } + + time.Sleep(timeoutDuration) + } + + return fmt.Errorf("server did not become ready within %v", timeout) +} diff --git a/internal/toolsets/config/tools.go b/internal/toolsets/config/tools.go new file mode 100644 index 0000000..ab7c95b --- /dev/null +++ b/internal/toolsets/config/tools.go @@ -0,0 +1,61 @@ +package config + +import ( + "context" + "errors" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// listClustersInput defines the input parameters for list_clusters tool. +type listClustersInput struct{} + +// listClustersOutput defines the output structure for list_clusters tool. +type listClustersOutput struct { + Clusters []string `json:"clusters"` +} + +// listClustersTool implements the list_clusters tool. +type listClustersTool struct { + name string +} + +// NewListClustersTool creates a new list_clusters tool. +func NewListClustersTool() toolsets.Tool { + return &listClustersTool{ + name: "list_clusters", + } +} + +// IsReadOnly returns true as this tool only reads data. +func (t *listClustersTool) IsReadOnly() bool { + return true +} + +// GetName returns the tool name. +func (t *listClustersTool) GetName() string { + return t.name +} + +// GetTool returns the MCP Tool definition. +func (t *listClustersTool) GetTool() *mcp.Tool { + return &mcp.Tool{ + Name: t.name, + Description: "List all clusters managed by StackRox Central with their IDs, names, and types", + } +} + +// RegisterWith registers the list_clusters tool handler with the MCP server. +func (t *listClustersTool) RegisterWith(server *mcp.Server) { + mcp.AddTool(server, t.GetTool(), t.handle) +} + +// handle is the placeholder handler for list_clusters tool. +func (t *listClustersTool) handle( + _ context.Context, + _ *mcp.CallToolRequest, + _ listClustersInput, +) (*mcp.CallToolResult, *listClustersOutput, error) { + return nil, nil, errors.New("list_clusters tool is not yet implemented") +} diff --git a/internal/toolsets/config/tools_test.go b/internal/toolsets/config/tools_test.go new file mode 100644 index 0000000..9712d95 --- /dev/null +++ b/internal/toolsets/config/tools_test.go @@ -0,0 +1,48 @@ +package config + +import ( + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewListClustersTool(t *testing.T) { + tool := NewListClustersTool() + + require.NotNil(t, tool) + assert.Equal(t, "list_clusters", tool.GetName()) +} + +func TestListClustersTool_IsReadOnly(t *testing.T) { + tool := NewListClustersTool() + + assert.True(t, tool.IsReadOnly(), "list_clusters should be read-only") +} + +func TestListClustersTool_GetTool(t *testing.T) { + tool := NewListClustersTool() + + mcpTool := tool.GetTool() + + require.NotNil(t, mcpTool) + assert.Equal(t, "list_clusters", mcpTool.Name) + assert.NotEmpty(t, mcpTool.Description) +} + +func TestListClustersTool_RegisterWith(t *testing.T) { + tool := NewListClustersTool() + server := mcp.NewServer( + &mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + &mcp.ServerOptions{}, + ) + + // Should not panic + assert.NotPanics(t, func() { + tool.RegisterWith(server) + }) +} diff --git a/internal/toolsets/config/toolset.go b/internal/toolsets/config/toolset.go new file mode 100644 index 0000000..4c5ce03 --- /dev/null +++ b/internal/toolsets/config/toolset.go @@ -0,0 +1,42 @@ +// Package config provides functionality for config manager toolset. +package config + +import ( + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// Toolset implements the config management toolset. +type Toolset struct { + cfg *config.Config + tools []toolsets.Tool +} + +// NewToolset creates a new config management toolset. +func NewToolset(cfg *config.Config) *Toolset { + return &Toolset{ + cfg: cfg, + tools: []toolsets.Tool{ + NewListClustersTool(), + }, + } +} + +// GetName returns the toolset name. +func (t *Toolset) GetName() string { + return "config_manager" +} + +// IsEnabled checks if this toolset is enabled in configuration. +func (t *Toolset) IsEnabled() bool { + return t.cfg.Tools.ConfigManager.Enabled +} + +// GetTools returns all tools. +func (t *Toolset) GetTools() []toolsets.Tool { + if !t.IsEnabled() { + return []toolsets.Tool{} + } + + return t.tools +} diff --git a/internal/toolsets/config/toolset_test.go b/internal/toolsets/config/toolset_test.go new file mode 100644 index 0000000..b062e23 --- /dev/null +++ b/internal/toolsets/config/toolset_test.go @@ -0,0 +1,49 @@ +package config + +import ( + "testing" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewToolset(t *testing.T) { + toolset := NewToolset(&config.Config{}) + require.NotNil(t, toolset) + assert.Equal(t, "config_manager", toolset.GetName()) +} + +func TestToolset_IsEnabled_True(t *testing.T) { + cfg := &config.Config{ + Tools: config.ToolsConfig{ + ConfigManager: config.ToolConfigManagerConfig{ + Enabled: true, + }, + }, + } + + toolset := NewToolset(cfg) + assert.True(t, toolset.IsEnabled()) + + tools := toolset.GetTools() + require.NotEmpty(t, tools, "Should return tools when enabled") + require.Len(t, tools, 1, "Should have list_clusters tool") + assert.Equal(t, "list_clusters", tools[0].GetName()) +} + +func TestToolset_IsEnabled_False(t *testing.T) { + cfg := &config.Config{ + Tools: config.ToolsConfig{ + ConfigManager: config.ToolConfigManagerConfig{ + Enabled: false, + }, + }, + } + + toolset := NewToolset(cfg) + assert.False(t, toolset.IsEnabled()) + + tools := toolset.GetTools() + assert.Empty(t, tools, "Should return empty list when toolset is disabled") +} diff --git a/internal/toolsets/mock/tool.go b/internal/toolsets/mock/tool.go new file mode 100644 index 0000000..a2a7bc8 --- /dev/null +++ b/internal/toolsets/mock/tool.go @@ -0,0 +1,43 @@ +package mock + +import ( + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Tool is a mock implementation of the toolsets.Tool interface for testing. +type Tool struct { + NameValue string + ReadOnlyValue bool + RegisterCalled bool +} + +// NewTool creates a new mock tool with the given name and read-only status. +func NewTool(name string, readOnly bool) *Tool { + return &Tool{ + NameValue: name, + ReadOnlyValue: readOnly, + } +} + +// IsReadOnly returns whether this tool is read-only. +func (m *Tool) IsReadOnly() bool { + return m.ReadOnlyValue +} + +// GetName returns the tool name. +func (m *Tool) GetName() string { + return m.NameValue +} + +// GetTool returns the MCP Tool definition. +func (m *Tool) GetTool() *mcp.Tool { + return &mcp.Tool{ + Name: m.NameValue, + Description: "Mock tool for testing", + } +} + +// RegisterWith tracks that the tool was registered with the MCP server. +func (m *Tool) RegisterWith(_ *mcp.Server) { + m.RegisterCalled = true +} diff --git a/internal/toolsets/mock/toolset.go b/internal/toolsets/mock/toolset.go new file mode 100644 index 0000000..da38563 --- /dev/null +++ b/internal/toolsets/mock/toolset.go @@ -0,0 +1,42 @@ +// Package mock holds mocks for Tool and Toolset interfaces. +package mock + +import ( + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// Toolset is a mock implementation of the toolsets.Toolset interface for testing. +type Toolset struct { + NameValue string + EnabledValue bool + ToolsValue []toolsets.Tool +} + +// NewToolset creates a new mock toolset with the given parameters. +func NewToolset(name string, enabled bool, tools []toolsets.Tool) *Toolset { + return &Toolset{ + NameValue: name, + EnabledValue: enabled, + ToolsValue: tools, + } +} + +// GetName returns the toolset name. +func (m *Toolset) GetName() string { + return m.NameValue +} + +// IsEnabled returns whether this toolset is enabled. +func (m *Toolset) IsEnabled() bool { + return m.EnabledValue +} + +// GetTools returns the tools in this toolset. +// If the toolset is disabled, returns an empty slice. +func (m *Toolset) GetTools() []toolsets.Tool { + if !m.EnabledValue { + return []toolsets.Tool{} + } + + return m.ToolsValue +} diff --git a/internal/toolsets/registry.go b/internal/toolsets/registry.go new file mode 100644 index 0000000..dab0fe8 --- /dev/null +++ b/internal/toolsets/registry.go @@ -0,0 +1,43 @@ +package toolsets + +import ( + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stackrox/stackrox-mcp/internal/config" +) + +// Registry manages all available toolsets and collects their tools. +type Registry struct { + cfg *config.Config + toolsets []Toolset +} + +// NewRegistry creates a new registry with the given config and toolsets. +func NewRegistry(cfg *config.Config, toolsets []Toolset) *Registry { + return &Registry{ + cfg: cfg, + toolsets: toolsets, + } +} + +// GetAllTools returns all enabled tools from all enabled toolsets. +// Each toolset handles its own enabled check. +func (r *Registry) GetAllTools() []*mcp.Tool { + tools := make([]*mcp.Tool, 0) + + for _, toolset := range r.toolsets { + for _, tool := range toolset.GetTools() { + if r.cfg.Global.ReadOnlyTools && !tool.IsReadOnly() { + continue + } + + tools = append(tools, tool.GetTool()) + } + } + + return tools +} + +// GetToolsets returns all registered toolsets (for debugging/testing). +func (r *Registry) GetToolsets() []Toolset { + return r.toolsets +} diff --git a/internal/toolsets/registry_test.go b/internal/toolsets/registry_test.go new file mode 100644 index 0000000..6ac9cd0 --- /dev/null +++ b/internal/toolsets/registry_test.go @@ -0,0 +1,165 @@ +package toolsets_test + +import ( + "testing" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRegistry(t *testing.T) { + cfg := &config.Config{} + toolsetList := []toolsets.Toolset{ + mock.NewToolset("mock_toolset_1", true, []toolsets.Tool{ + mock.NewTool("tool_1", true), + }), + mock.NewToolset("mock_toolset_2", true, []toolsets.Tool{ + mock.NewTool("tool_2", true), + }), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + + require.NotNil(t, registry) + assert.Len(t, registry.GetToolsets(), 2) +} + +func TestRegistry_GetAllTools_AllToolsetsEnabled(t *testing.T) { + cfg := &config.Config{ + Global: config.GlobalConfig{ + ReadOnlyTools: false, + }, + } + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("mock_toolset_1", true, []toolsets.Tool{ + mock.NewTool("read_only_tool", true), + }), + mock.NewToolset("mock_toolset_2", true, []toolsets.Tool{ + mock.NewTool("read_write_tool", false), + }), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + + tools := registry.GetAllTools() + + require.NotEmpty(t, tools) + assert.Len(t, tools, 2, "Should have tools from both toolsets") + + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + assert.True(t, toolNames["read_only_tool"], "Should have read_only_tool") + assert.True(t, toolNames["read_write_tool"], "Should have read_write_tool") +} + +func TestRegistry_GetAllTools_OneToolsetDisabled(t *testing.T) { + cfg := &config.Config{ + Global: config.GlobalConfig{ + ReadOnlyTools: false, + }, + } + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("enabled_toolset", true, []toolsets.Tool{ + mock.NewTool("enabled_tool", true), + }), + mock.NewToolset("disabled_toolset", false, []toolsets.Tool{ + mock.NewTool("disabled_tool", true), + }), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + + tools := registry.GetAllTools() + + require.NotEmpty(t, tools) + require.Len(t, tools, 1, "Should only have tools from enabled toolset") + assert.Equal(t, "enabled_tool", tools[0].Name) +} + +func TestRegistry_GetAllTools_AllToolsetsDisabled(t *testing.T) { + cfg := &config.Config{} + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("disabled_toolset_1", false, []toolsets.Tool{ + mock.NewTool("tool_1", true), + }), + mock.NewToolset("disabled_toolset_2", false, []toolsets.Tool{ + mock.NewTool("tool_2", true), + }), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + + tools := registry.GetAllTools() + + assert.Empty(t, tools, "Should return empty list when all toolsets are disabled") +} + +func TestRegistry_GetAllTools_FiltersReadWriteTools(t *testing.T) { + cfg := &config.Config{ + Global: config.GlobalConfig{ + ReadOnlyTools: true, + }, + } + + toolsetList := []toolsets.Toolset{ + mock.NewToolset("mixed_toolset", true, []toolsets.Tool{ + mock.NewTool("read_only_1", true), + mock.NewTool("read_write_1", false), + mock.NewTool("read_only_2", true), + mock.NewTool("read_write_2", false), + }), + } + + registry := toolsets.NewRegistry(cfg, toolsetList) + tools := registry.GetAllTools() + + require.Len(t, tools, 2, "Should filter out read-write tools") + + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + assert.True(t, toolNames["read_only_1"], "Should have read-only tools") + assert.True(t, toolNames["read_only_2"], "Should have read-only tools") +} + +func TestRegistry_GetToolsets(t *testing.T) { + cfg := &config.Config{} + + mockToolset1 := mock.NewToolset("toolset_1", true, []toolsets.Tool{ + mock.NewTool("tool_1", true), + }) + mockToolset2 := mock.NewToolset("toolset_2", true, []toolsets.Tool{ + mock.NewTool("tool_2", true), + }) + + toolsetList := []toolsets.Toolset{mockToolset1, mockToolset2} + registry := toolsets.NewRegistry(cfg, toolsetList) + + retrievedToolsets := registry.GetToolsets() + + require.Len(t, retrievedToolsets, 2) + assert.Contains(t, retrievedToolsets, mockToolset1) + assert.Contains(t, retrievedToolsets, mockToolset2) +} + +func TestRegistry_EmptyToolsets(t *testing.T) { + cfg := &config.Config{} + registry := toolsets.NewRegistry(cfg, []toolsets.Toolset{}) + + tools := registry.GetAllTools() + toolsetList := registry.GetToolsets() + + assert.Empty(t, tools) + assert.Empty(t, toolsetList) +} diff --git a/internal/toolsets/toolset.go b/internal/toolsets/toolset.go new file mode 100644 index 0000000..af9a9d3 --- /dev/null +++ b/internal/toolsets/toolset.go @@ -0,0 +1,33 @@ +// Package toolsets handles tools and toolsets registration. +package toolsets + +import ( + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Tool represents a single MCP tool with metadata. +type Tool interface { + // IsReadOnly returns true if the tool only performs read operations. + IsReadOnly() bool + + // GetTool returns the MCP SDK Tool definition. + GetTool() *mcp.Tool + + // GetName returns the tool name for logging/debugging. + GetName() string + + // RegisterWith registers the tool's handler with the MCP server. + RegisterWith(server *mcp.Server) +} + +// Toolset represents a collection of related tools. +type Toolset interface { + // GetName returns the toolset name. + GetName() string + + // IsEnabled checks if this toolset is enabled in configuration. + IsEnabled() bool + + // GetTools returns available tools based on configuration. + GetTools() []Tool +} diff --git a/internal/toolsets/vulnerability/tools.go b/internal/toolsets/vulnerability/tools.go new file mode 100644 index 0000000..03f2bc3 --- /dev/null +++ b/internal/toolsets/vulnerability/tools.go @@ -0,0 +1,64 @@ +package vulnerability + +import ( + "context" + "errors" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// listClusterCVEsInput defines the input parameters for list_cluster_cves tool. +type listClusterCVEsInput struct { + ClusterID string `json:"clusterId,omitempty"` +} + +// listClusterCVEsOutput defines the output structure for list_cluster_cves tool. +type listClusterCVEsOutput struct { + CVEs []string `json:"cves"` +} + +// listClusterCVEsTool implements the list_cluster_cves tool. +type listClusterCVEsTool struct { + name string +} + +// NewListClusterCVEsTool creates a new list_cluster_cves tool. +func NewListClusterCVEsTool() toolsets.Tool { + return &listClusterCVEsTool{ + name: "list_cluster_cves", + } +} + +// IsReadOnly returns true as this tool only reads data. +func (t *listClusterCVEsTool) IsReadOnly() bool { + return true +} + +// GetName returns the tool name. +func (t *listClusterCVEsTool) GetName() string { + return t.name +} + +// GetTool returns the MCP Tool definition. +func (t *listClusterCVEsTool) GetTool() *mcp.Tool { + return &mcp.Tool{ + Name: t.name, + //nolint:lll + Description: "List CVEs affecting a specific cluster or all clusters in StackRox Central with CVE names, scores, affected images, and deployments", + } +} + +// RegisterWith registers the list_cluster_cves tool handler with the MCP server. +func (t *listClusterCVEsTool) RegisterWith(server *mcp.Server) { + mcp.AddTool(server, t.GetTool(), t.handle) +} + +// handle is the placeholder handler for list_cluster_cves tool. +func (t *listClusterCVEsTool) handle( + _ context.Context, + _ *mcp.CallToolRequest, + _ listClusterCVEsInput, +) (*mcp.CallToolResult, *listClusterCVEsOutput, error) { + return nil, nil, errors.New("list_cluster_cves tool is not yet implemented") +} diff --git a/internal/toolsets/vulnerability/tools_test.go b/internal/toolsets/vulnerability/tools_test.go new file mode 100644 index 0000000..9e2b0a6 --- /dev/null +++ b/internal/toolsets/vulnerability/tools_test.go @@ -0,0 +1,48 @@ +package vulnerability + +import ( + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewListClusterCVEsTool(t *testing.T) { + tool := NewListClusterCVEsTool() + + require.NotNil(t, tool) + assert.Equal(t, "list_cluster_cves", tool.GetName()) +} + +func TestListClusterCVEsTool_IsReadOnly(t *testing.T) { + tool := NewListClusterCVEsTool() + + assert.True(t, tool.IsReadOnly(), "list_cluster_cves should be read-only") +} + +func TestListClusterCVEsTool_GetTool(t *testing.T) { + tool := NewListClusterCVEsTool() + + mcpTool := tool.GetTool() + + require.NotNil(t, mcpTool) + assert.Equal(t, "list_cluster_cves", mcpTool.Name) + assert.NotEmpty(t, mcpTool.Description) +} + +func TestListClusterCVEsTool_RegisterWith(t *testing.T) { + tool := NewListClusterCVEsTool() + server := mcp.NewServer( + &mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + &mcp.ServerOptions{}, + ) + + // Should not panic + assert.NotPanics(t, func() { + tool.RegisterWith(server) + }) +} diff --git a/internal/toolsets/vulnerability/toolset.go b/internal/toolsets/vulnerability/toolset.go new file mode 100644 index 0000000..70950f7 --- /dev/null +++ b/internal/toolsets/vulnerability/toolset.go @@ -0,0 +1,42 @@ +// Package vulnerability provides functionality for vulnerability toolset. +package vulnerability + +import ( + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// Toolset implements the vulnerability management toolset. +type Toolset struct { + cfg *config.Config + tools []toolsets.Tool +} + +// NewToolset creates a new vulnerability management toolset. +func NewToolset(cfg *config.Config) *Toolset { + return &Toolset{ + cfg: cfg, + tools: []toolsets.Tool{ + NewListClusterCVEsTool(), + }, + } +} + +// GetName returns the toolset name. +func (t *Toolset) GetName() string { + return "vulnerability" +} + +// IsEnabled checks if this toolset is enabled in configuration. +func (t *Toolset) IsEnabled() bool { + return t.cfg.Tools.Vulnerability.Enabled +} + +// GetTools returns all tools. +func (t *Toolset) GetTools() []toolsets.Tool { + if !t.IsEnabled() { + return []toolsets.Tool{} + } + + return t.tools +} diff --git a/internal/toolsets/vulnerability/toolset_test.go b/internal/toolsets/vulnerability/toolset_test.go new file mode 100644 index 0000000..b4cbfaf --- /dev/null +++ b/internal/toolsets/vulnerability/toolset_test.go @@ -0,0 +1,58 @@ +package vulnerability + +import ( + "testing" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewToolset(t *testing.T) { + cfg := &config.Config{ + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: true, + }, + }, + } + + toolset := NewToolset(cfg) + + require.NotNil(t, toolset) + assert.Equal(t, "vulnerability", toolset.GetName()) +} + +func TestToolset_IsEnabled_True(t *testing.T) { + cfg := &config.Config{ + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: true, + }, + }, + } + + toolset := NewToolset(cfg) + assert.True(t, toolset.IsEnabled()) + + tools := toolset.GetTools() + require.NotEmpty(t, tools, "Should return tools when enabled") + require.Len(t, tools, 1, "Should have list_cluster_cves tool") + assert.Equal(t, "list_cluster_cves", tools[0].GetName()) +} + +func TestToolset_IsEnabled_False(t *testing.T) { + cfg := &config.Config{ + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: false, + }, + }, + } + + toolset := NewToolset(cfg) + assert.False(t, toolset.IsEnabled()) + + tools := toolset.GetTools() + assert.Empty(t, tools, "Should return empty list when toolset is disabled") +}