diff --git a/src/cmap/cerrmap.go b/src/cmap/cerrmap.go index 687c9191b..78c7dc8ed 100644 --- a/src/cmap/cerrmap.go +++ b/src/cmap/cerrmap.go @@ -78,3 +78,14 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) { } return v.Val, v.Err } + +// Range calls f for each key-value pair in the map. +// No particular consistency guarantees are made during iteration. +func (m *ErrMap[K, V]) Range(f func(key K, val V)) { + m.m.Range(func(key K, val errV[V]) { + if val.Err != nil { + return // skip errors + } + f(key, val.Val) + }) +} diff --git a/src/cmap/cmap.go b/src/cmap/cmap.go index ce8508b45..f8058ef73 100644 --- a/src/cmap/cmap.go +++ b/src/cmap/cmap.go @@ -94,6 +94,14 @@ func (m *Map[K, V]) Values() []V { return ret } +// Range calls f for each key-value pair in the map. +// No particular consistency guarantees are made during iteration. +func (m *Map[K, V]) Range(f func(key K, val V)) { + for i := 0; i < len(m.shards); i++ { + m.shards[i].Range(f) + } +} + // An awaitableValue represents a value in the map & an awaitable channel for it to exist. type awaitableValue[V any] struct { Val V @@ -195,3 +203,14 @@ func (s *shard[K, V]) Contains(key K) bool { _, ok := s.m[key] return ok } + +// Range calls f for each key-value pair in this shard. +func (s *shard[K, V]) Range(f func(key K, val V)) { + s.l.RLock() + defer s.l.RUnlock() + for k, v := range s.m { + if v.Wait == nil { // Only include completed values + f(k, v.Val) + } + } +} diff --git a/src/parse/asp/parser.go b/src/parse/asp/parser.go index 36f055b96..f67a8605a 100644 --- a/src/parse/asp/parser.go +++ b/src/parse/asp/parser.go @@ -257,6 +257,24 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) { } } +// AllFunctionsByFile returns all function definitions grouped by filename. +// This includes functions from builtins, plugins, and subincludes. +// It iterates over the ASTs stored by the interpreter. +func (p *Parser) AllFunctionsByFile() map[string][]*Statement { + if p.interpreter == nil || p.interpreter.asts == nil { + return nil + } + result := make(map[string][]*Statement) + p.interpreter.asts.Range(func(filename string, stmts []*Statement) { + for _, stmt := range stmts { + if stmt.FuncDef != nil { + result[filename] = append(result[filename], stmt) + } + } + }) + return result +} + // whitelistedKwargs returns true if the given built-in function name is allowed to // be called as non-kwargs. // TODO(peterebden): Come up with a syntax that exposes this directly in the file. diff --git a/src/parse/init.go b/src/parse/init.go index 663e26510..ee67dacda 100644 --- a/src/parse/init.go +++ b/src/parse/init.go @@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState { return state } +// GetAspParser returns the underlying asp.Parser from the state's parser. +// This is useful for tools like the language server that need direct access to AST information. +// Returns nil if the state's parser is not set or is not an aspParser. +func GetAspParser(state *core.BuildState) *asp.Parser { + if state.Parser == nil { + return nil + } + if ap, ok := state.Parser.(*aspParser); ok { + return ap.parser + } + return nil +} + // aspParser implements the core.Parser interface around our parser package. type aspParser struct { parser *asp.Parser diff --git a/tools/build_langserver/lsp/BUILD b/tools/build_langserver/lsp/BUILD index f8c5f3485..35cefd6cd 100644 --- a/tools/build_langserver/lsp/BUILD +++ b/tools/build_langserver/lsp/BUILD @@ -5,6 +5,7 @@ go_library( "definition.go", "diagnostics.go", "lsp.go", + "references.go", "symbols.go", "text.go", ], @@ -17,8 +18,10 @@ go_library( "//rules", "//src/core", "//src/fs", + "//src/parse", "//src/parse/asp", "//src/plz", + "//src/query", "//tools/build_langserver/lsp/astutils", ], ) @@ -29,15 +32,20 @@ go_test( srcs = [ "definition_test.go", "lsp_test.go", + "references_test.go", "symbols_test.go", ], - data = ["test_data"], + data = [ + "test_data", + "test_data_find_references", + ], deps = [ ":lsp", "///third_party/go/github.com_please-build_buildtools//build", "///third_party/go/github.com_sourcegraph_go-lsp//:go-lsp", "///third_party/go/github.com_sourcegraph_jsonrpc2//:jsonrpc2", "///third_party/go/github.com_stretchr_testify//assert", + "///third_party/go/github.com_stretchr_testify//require", "//src/cli", "//src/core", ], diff --git a/tools/build_langserver/lsp/completion.go b/tools/build_langserver/lsp/completion.go index 081af9194..4348b24fb 100644 --- a/tools/build_langserver/lsp/completion.go +++ b/tools/build_langserver/lsp/completion.go @@ -109,10 +109,10 @@ func (h *Handler) completeString(doc *doc, s string, line, col int) (*lsp.Comple // completeIdent completes an arbitrary identifier func (h *Handler) completeIdent(doc *doc, s string, line, col int) (*lsp.CompletionList, error) { list := &lsp.CompletionList{} - for name, f := range h.builtins { - if strings.HasPrefix(name, s) { + for name, builtins := range h.builtins { + if strings.HasPrefix(name, s) && len(builtins) > 0 { item := completionItem(name, "", line, col) - item.Documentation = f.Stmt.FuncDef.Docstring + item.Documentation = builtins[0].Stmt.FuncDef.Docstring item.Kind = lsp.CIKFunction list.Items = append(list.Items, item) } diff --git a/tools/build_langserver/lsp/definition.go b/tools/build_langserver/lsp/definition.go index 9ee1c5df1..3085d5d03 100644 --- a/tools/build_langserver/lsp/definition.go +++ b/tools/build_langserver/lsp/definition.go @@ -18,20 +18,20 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca ast := h.parseIfNeeded(doc) f := doc.AspFile() - var locs []lsp.Location + locs := []lsp.Location{} pos := aspPos(params.Position) asp.WalkAST(ast, func(expr *asp.Expression) bool { - if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) { + exprStart := f.Pos(expr.Pos) + exprEnd := f.Pos(expr.EndPos) + if !asp.WithinRange(pos, exprStart, exprEnd) { return false } - if expr.Val.Ident != nil { if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" { locs = append(locs, loc) } return false } - if expr.Val.String != "" { label := astutils.TrimStrLit(expr.Val.String) if loc := h.findLabel(doc.PkgName, label); loc.URI != "" { @@ -39,20 +39,19 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca } return false } - return true }) - // It might also be a statement. + // It might also be a statement (e.g. a function call like go_library(...)) asp.WalkAST(ast, func(stmt *asp.Statement) bool { if stmt.Ident != nil { - endPos := f.Pos(stmt.Pos) + stmtStart := f.Pos(stmt.Pos) + endPos := stmtStart // TODO(jpoole): The AST should probably just have this information endPos.Column += len(stmt.Ident.Name) - if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) { - return false + if !asp.WithinRange(pos, stmtStart, endPos) { + return true // continue to other statements } - if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" { locs = append(locs, loc) } @@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location { } pkg := h.state.Graph.PackageByLabel(l) + if pkg == nil { + return lsp.Location{} + } uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) loc := lsp.Location{URI: uri} doc, err := h.maybeOpenDoc(uri) @@ -137,9 +139,18 @@ func findName(args []asp.CallArgument) string { // findGlobal returns the location of a global of the given name. func (h *Handler) findGlobal(name string) lsp.Location { - if f, present := h.builtins[name]; present { + h.mutex.Lock() + builtins := h.builtins[name] + h.mutex.Unlock() + if len(builtins) > 0 { + f := builtins[0] + filename := f.Pos.Filename + // Make path absolute if it's relative + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } return lsp.Location{ - URI: lsp.DocumentURI("file://" + f.Pos.Filename), + URI: lsp.DocumentURI("file://" + filename), Range: rng(f.Pos, f.EndPos), } } diff --git a/tools/build_langserver/lsp/lsp.go b/tools/build_langserver/lsp/lsp.go index b97936017..3ec25dc37 100644 --- a/tools/build_langserver/lsp/lsp.go +++ b/tools/build_langserver/lsp/lsp.go @@ -20,6 +20,7 @@ import ( "github.com/thought-machine/please/rules" "github.com/thought-machine/please/src/core" "github.com/thought-machine/please/src/fs" + "github.com/thought-machine/please/src/parse" "github.com/thought-machine/please/src/parse/asp" "github.com/thought-machine/please/src/plz" ) @@ -33,7 +34,7 @@ type Handler struct { mutex sync.Mutex // guards docs state *core.BuildState parser *asp.Parser - builtins map[string]builtin + builtins map[string][]builtin pkgs *pkg root string } @@ -41,6 +42,7 @@ type Handler struct { type builtin struct { Stmt *asp.Statement Pos, EndPos asp.FilePosition + Labels []core.BuildLabel // which subinclude targets can provide this (empty for core builtins) } // A Conn is a minimal set of the jsonrpc2.Conn that we need. @@ -55,7 +57,7 @@ func NewHandler() *Handler { return &Handler{ docs: map[string]*doc{}, pkgs: &pkg{}, - builtins: map[string]builtin{}, + builtins: map[string][]builtin{}, } } @@ -173,6 +175,12 @@ func (h *Handler) handle(method string, params *json.RawMessage) (res interface{ return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} } return h.definition(positionParams) + case "textDocument/references": + referenceParams := &lsp.ReferenceParams{} + if err := json.Unmarshal(*params, referenceParams); err != nil { + return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} + } + return h.references(referenceParams) default: return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound} } @@ -195,13 +203,35 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul } h.state = core.NewBuildState(config) h.state.NeedBuild = false - // We need an unwrapped parser instance as well for raw access. - h.parser = asp.NewParser(h.state) + // Initialize the parser on state first, so that plz.RunHost uses the same parser. + // This ensures plugin subincludes are stored in the same AST cache we use. + parse.InitParser(h.state) + h.parser = parse.GetAspParser(h.state) + if h.parser == nil { + return nil, fmt.Errorf("failed to get asp parser from state") + } // Parse everything in the repo up front. // This is a lot easier than trying to do clever partial parses later on, although // eventually we may want that if we start dealing with truly large repos. go func() { + // Start a goroutine to periodically load parser functions as they become available. + // This allows go-to-definition to work progressively while the full parse runs. + done := make(chan struct{}) + go func() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-done: + h.loadParserFunctions() + return + case <-ticker.C: + h.loadParserFunctions() + } + } + }() plz.RunHost(core.WholeGraph, h.state) + close(done) log.Debug("initial parse complete") h.buildPackageTree() log.Debug("built completion package tree") @@ -221,6 +251,7 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul DocumentFormattingProvider: true, DocumentSymbolProvider: true, DefinitionProvider: true, + ReferencesProvider: true, CompletionProvider: &lsp.CompletionOptions{ TriggerCharacters: []string{"/", ":"}, }, @@ -256,11 +287,11 @@ func (h *Handler) loadBuiltins() error { f := asp.NewFile(dest, data) for _, stmt := range stmts { if stmt.FuncDef != nil { - h.builtins[stmt.FuncDef.Name] = builtin{ + h.builtins[stmt.FuncDef.Name] = append(h.builtins[stmt.FuncDef.Name], builtin{ Stmt: stmt, Pos: f.Pos(stmt.Pos), EndPos: f.Pos(stmt.EndPos), - } + }) } } } @@ -268,6 +299,60 @@ func (h *Handler) loadBuiltins() error { return nil } +// loadParserFunctions loads function definitions from the parser's ASTs. +// This includes plugin-defined functions like go_library, python_library, etc. +func (h *Handler) loadParserFunctions() { + funcsByFile := h.parser.AllFunctionsByFile() + if funcsByFile == nil { + return + } + h.mutex.Lock() + defer h.mutex.Unlock() + for filename, stmts := range funcsByFile { + // Read the file to create a File object for position conversion + data, err := os.ReadFile(filename) + if err != nil { + log.Warning("failed to read file %s: %v", filename, err) + continue + } + file := asp.NewFile(filename, data) + labels := h.findLabelsForFile(filename) + for _, stmt := range stmts { + name := stmt.FuncDef.Name + h.builtins[name] = append(h.builtins[name], builtin{ + Stmt: stmt, + Pos: file.Pos(stmt.Pos), + EndPos: file.Pos(stmt.EndPos), + Labels: labels, + }) + } + } +} + +// findLabelsForFile finds all build labels that produce the given output file. +// A file can be exposed by multiple targets (e.g., multiple filegroups pointing to the same source). +func (h *Handler) findLabelsForFile(filename string) []core.BuildLabel { + // Make the filename absolute for comparison + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } + var labels []core.BuildLabel + for _, pkg := range h.state.Graph.PackageMap() { + for _, target := range pkg.AllTargets() { + for _, out := range target.FullOutputs() { + absOut := out + if !filepath.IsAbs(out) { + absOut = filepath.Join(h.root, out) + } + if absOut == filename { + labels = append(labels, target.Label) + } + } + } + } + return labels +} + // fromURI converts a DocumentURI to a path. func fromURI(uri lsp.DocumentURI) string { if !strings.HasPrefix(string(uri), "file://") { diff --git a/tools/build_langserver/lsp/lsp_test.go b/tools/build_langserver/lsp/lsp_test.go index ca7a0b4c7..8d05240f8 100644 --- a/tools/build_langserver/lsp/lsp_test.go +++ b/tools/build_langserver/lsp/lsp_test.go @@ -458,7 +458,7 @@ func TestCompletionFunction(t *testing.T) { Kind: lsp.CIKFunction, InsertTextFormat: lsp.ITFPlainText, TextEdit: textEdit("plugin_repo", 0, 4, 0), - Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring, + Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring, }}, }, completions) } @@ -492,7 +492,7 @@ func TestCompletionPartialFunction(t *testing.T) { Kind: lsp.CIKFunction, InsertTextFormat: lsp.ITFPlainText, TextEdit: textEdit("plugin_repo", 0, 9, 0), - Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring, + Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring, }}, }, completions) } diff --git a/tools/build_langserver/lsp/references.go b/tools/build_langserver/lsp/references.go new file mode 100644 index 000000000..f9413896c --- /dev/null +++ b/tools/build_langserver/lsp/references.go @@ -0,0 +1,287 @@ +package lsp + +import ( + "path/filepath" + "slices" + "strings" + + "github.com/sourcegraph/go-lsp" + + "github.com/thought-machine/please/src/core" + "github.com/thought-machine/please/src/parse/asp" + "github.com/thought-machine/please/src/query" + "github.com/thought-machine/please/tools/build_langserver/lsp/astutils" +) + +// references implements 'find all references' support. +func (h *Handler) references(params *lsp.ReferenceParams) ([]lsp.Location, error) { + doc := h.doc(params.TextDocument.URI) + ast := h.parseIfNeeded(doc) + f := doc.AspFile() + pos := aspPos(params.Position) + + // Check if cursor is on a function definition (def funcname(...)) + var funcName string + asp.WalkAST(ast, func(stmt *asp.Statement) bool { + if stmt.FuncDef != nil { + stmtStart := f.Pos(stmt.Pos) + // Check if cursor is on the function name + nameEnd := stmtStart + nameEnd.Column += len("def ") + len(stmt.FuncDef.Name) + if asp.WithinRange(pos, stmtStart, nameEnd) { + funcName = stmt.FuncDef.Name + return false + } + } + return true + }) + + // If we found a function definition, find all calls to it + if funcName != "" { + // Pass the current filename so we can find the labels for THIS definition + return h.findFunctionReferences(funcName, f.Name, params.Context.IncludeDeclaration) + } + + // Check if cursor is on a function call (e.g., go_library(...)) + asp.WalkAST(ast, func(stmt *asp.Statement) bool { + if stmt.Ident != nil { + stmtStart := f.Pos(stmt.Pos) + nameEnd := stmtStart + nameEnd.Column += len(stmt.Ident.Name) + if asp.WithinRange(pos, stmtStart, nameEnd) { + funcName = stmt.Ident.Name + return false + } + } + return true + }) + + if funcName != "" { + // For function calls, we don't have a specific definition file - use empty string + return h.findFunctionReferences(funcName, "", params.Context.IncludeDeclaration) + } + + // Otherwise, look for build label references + return h.findLabelReferences(doc, ast, f, pos, params.Context.IncludeDeclaration) +} + +// findFunctionReferences finds all calls to a function across all BUILD files. +// sourceFile is the file where the definition we're searching from is located (empty if from a call site). +func (h *Handler) findFunctionReferences(funcName string, sourceFile string, includeDeclaration bool) ([]lsp.Location, error) { + locs := []lsp.Location{} + + // Get the labels that can provide this function (a file may be exposed by multiple targets) + h.mutex.Lock() + builtins := h.builtins[funcName] + var defLabels []core.BuildLabel + // If we know the source file, find the labels for THAT specific definition + // This prevents false positives when multiple files define functions with the same name + for _, b := range builtins { + if sourceFile != "" && pathsMatch(sourceFile, b.Pos.Filename, h.root) { + defLabels = b.Labels + break + } + } + // If no source file specified or not found, fall back to first definition's labels + if len(defLabels) == 0 && len(builtins) > 0 { + defLabels = builtins[0].Labels + } + h.mutex.Unlock() + + // Search all packages for calls to this function + for _, pkg := range h.state.Graph.PackageMap() { + // If we know which labels define this function, only search packages that subinclude at least one + if len(defLabels) > 0 { + allSubincludes := pkg.AllSubincludes(h.state.Graph) + hasSubinclude := false + for _, defLabel := range defLabels { + if slices.Contains(allSubincludes, defLabel) { + hasSubinclude = true + break + } + } + if !hasSubinclude { + continue + } + } + + uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) + refDoc, err := h.maybeOpenDoc(uri) + if err != nil { + continue + } + refAst := h.parseIfNeeded(refDoc) + refFile := refDoc.AspFile() + + // Find all statement calls to the function (e.g., go_library(...)) + asp.WalkAST(refAst, func(stmt *asp.Statement) bool { + if stmt.Ident != nil && stmt.Ident.Name == funcName && stmt.Ident.Action != nil && stmt.Ident.Action.Call != nil { + start := refFile.Pos(stmt.Pos) + end := start + end.Column += len(funcName) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + return true + }) + + // Find expression calls (e.g., x = go_library(...)) + asp.WalkAST(refAst, func(expr *asp.Expression) bool { + if expr.Val.Ident != nil && expr.Val.Ident.Name == funcName && len(expr.Val.Ident.Action) > 0 && expr.Val.Ident.Action[0].Call != nil { + start := refFile.Pos(expr.Pos) + end := start + end.Column += len(funcName) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + return true + }) + } + + // Include the definition itself if requested + if includeDeclaration { + h.mutex.Lock() + if builtins, ok := h.builtins[funcName]; ok && len(builtins) > 0 { + b := builtins[0] + filename := b.Pos.Filename + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } + locs = append(locs, lsp.Location{ + URI: lsp.DocumentURI("file://" + filename), + Range: rng(b.Pos, b.EndPos), + }) + } + h.mutex.Unlock() + } + + return locs, nil +} + +// findLabelReferences finds all references to a build label. +func (h *Handler) findLabelReferences(doc *doc, ast []*asp.Statement, f *asp.File, pos asp.FilePosition, includeDeclaration bool) ([]lsp.Location, error) { + var targetLabel core.BuildLabel + var targetName string + + // Check if cursor is on a string (build label) + asp.WalkAST(ast, func(expr *asp.Expression) bool { + exprStart := f.Pos(expr.Pos) + exprEnd := f.Pos(expr.EndPos) + if !asp.WithinRange(pos, exprStart, exprEnd) { + return false + } + if expr.Val.String != "" { + label := astutils.TrimStrLit(expr.Val.String) + if l, err := core.TryParseBuildLabel(label, doc.PkgName, ""); err == nil { + targetLabel = l + } + return false + } + return true + }) + + // Check if cursor is on a target definition (name = "...") + if targetLabel.IsEmpty() { + asp.WalkAST(ast, func(stmt *asp.Statement) bool { + if stmt.Ident != nil && stmt.Ident.Action != nil && stmt.Ident.Action.Call != nil { + stmtStart := f.Pos(stmt.Pos) + stmtEnd := f.Pos(stmt.EndPos) + if asp.WithinRange(pos, stmtStart, stmtEnd) { + if name := findName(stmt.Ident.Action.Call.Arguments); name != "" { + targetLabel = core.BuildLabel{PackageName: doc.PkgName, Name: name} + targetName = name + } + } + return false + } + return true + }) + } + + if targetLabel.IsEmpty() { + return []lsp.Location{}, nil + } + + // Verify the target exists in the graph before querying revdeps + // This prevents panics when the label has an invalid package name + if h.state.Graph.Package(targetLabel.PackageName, "") == nil { + log.Warning("references: package %q not found in build graph for label %s", targetLabel.PackageName, targetLabel) + return []lsp.Location{}, nil + } + + // Use query.FindRevdeps to find all reverse dependencies + // Parameters: hidden=false, followSubincludes=true, includeSubrepos=true, depth=-1 (unlimited) + revdeps := query.FindRevdeps(h.state, core.BuildLabels{targetLabel}, false, true, true, -1) + + locs := []lsp.Location{} + + // For each reverse dependency, find the exact location of the reference in its BUILD file + for target := range revdeps { + pkg := h.state.Graph.PackageByLabel(target.Label) + if pkg == nil { + continue + } + + uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) + refDoc, err := h.maybeOpenDoc(uri) + if err != nil { + continue + } + refAst := h.parseIfNeeded(refDoc) + refFile := refDoc.AspFile() + + // Find all string literals that reference our target + labelStr := targetLabel.String() + shortLabelStr := ":" + targetLabel.Name // For same-package references + + asp.WalkAST(refAst, func(expr *asp.Expression) bool { + if expr.Val.String != "" { + str := astutils.TrimStrLit(expr.Val.String) + // Check if this string matches our target label + if str == labelStr || (refDoc.PkgName == targetLabel.PackageName && str == shortLabelStr) { + // Also try parsing it as a label to handle relative references + if l, err := core.TryParseBuildLabel(str, refDoc.PkgName, ""); err == nil && l == targetLabel { + start := refFile.Pos(expr.Pos) + end := refFile.Pos(expr.EndPos) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + } + } + return true + }) + } + + // Optionally include the definition itself if requested + if includeDeclaration && targetName != "" { + if defLoc := h.findLabel(doc.PkgName, targetLabel.String()); defLoc.URI != "" { + locs = append(locs, defLoc) + } + } + + return locs, nil +} + +// pathsMatch checks if two file paths refer to the same file, handling plz-out/gen paths. +func pathsMatch(path1, path2, root string) bool { + // Strip plz-out/gen/ prefix if present + norm1 := strings.TrimPrefix(path1, "plz-out/gen/") + norm2 := strings.TrimPrefix(path2, "plz-out/gen/") + + return norm1 == norm2 +} diff --git a/tools/build_langserver/lsp/references_test.go b/tools/build_langserver/lsp/references_test.go new file mode 100644 index 000000000..2aafbaa31 --- /dev/null +++ b/tools/build_langserver/lsp/references_test.go @@ -0,0 +1,375 @@ +package lsp + +import ( + "os" + "path/filepath" + "testing" + + "github.com/sourcegraph/go-lsp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReferencesBuiltinFunction(t *testing.T) { + // Test finding references to go_library function using actual test data file + uri := "file://" + filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build") + h := initHandler() + content, err := os.ReadFile(filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build")) + require.NoError(t, err) + + err = h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: lsp.DocumentURI(uri), + Text: string(content), + }, + }, nil) + require.NoError(t, err) + h.WaitForPackage("src/core") + + var locs []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: lsp.DocumentURI(uri), + }, + Position: lsp.Position{Line: 0, Character: 5}, // on "go_library" + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: false, + }, + }, &locs) + require.NoError(t, err) + // go_library is widely used, we should find many references + assert.NotEmpty(t, locs, "expected to find references to go_library") +} + +func TestReferencesGoTestFunction(t *testing.T) { + // Test finding references to go_test function + uri := "file://" + filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build") + h := initHandler() + content, err := os.ReadFile(filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build")) + require.NoError(t, err) + + err = h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: lsp.DocumentURI(uri), + Text: string(content), + }, + }, nil) + require.NoError(t, err) + h.WaitForPackage("src/core") + + var locs []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: lsp.DocumentURI(uri), + }, + Position: lsp.Position{Line: 19, Character: 3}, // on "go_test" (line 20, 0-indexed = 19) + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: false, + }, + }, &locs) + require.NoError(t, err) + // go_test is widely used, we should find many references + assert.NotEmpty(t, locs, "expected to find references to go_test") +} + +func TestReferencesNoResults(t *testing.T) { + // Test that we get empty results when cursor is on a comment + h := initHandlerText(`# just a comment +# nothing referenceable here`) + h.WaitForPackageTree() + + var locs []lsp.Location + err := h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: testURI, + }, + Position: lsp.Position{Line: 0, Character: 5}, + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: false, + }, + }, &locs) + require.NoError(t, err) + assert.Empty(t, locs, "expected no references for a comment") +} + +func TestReferencesBuildLabel(t *testing.T) { + // Test that label references don't panic when package not in graph + h := initHandlerText(`go_library( + name = "core", + srcs = ["lib.go"], +) + +go_test( + name = "core_test", + srcs = ["lib_test.go"], + deps = [":core"], +)`) + h.WaitForPackageTree() + + var locs []lsp.Location + err := h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: testURI, + }, + // Position on the ":core" string in deps + Position: lsp.Position{Line: 8, Character: 13}, + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: false, + }, + }, &locs) + // Should not error even if package not in graph + require.NoError(t, err) +} + +func TestReferencesIncludeDeclaration(t *testing.T) { + // Test that IncludeDeclaration includes the function definition + uri := "file://" + filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build") + h := initHandler() + content, err := os.ReadFile(filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data/src/core/test.build")) + require.NoError(t, err) + + err = h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: lsp.DocumentURI(uri), + Text: string(content), + }, + }, nil) + require.NoError(t, err) + h.WaitForPackage("src/core") + + var locsWithDecl []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: lsp.DocumentURI(uri), + }, + Position: lsp.Position{Line: 0, Character: 5}, // on "go_library" + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: true, + }, + }, &locsWithDecl) + require.NoError(t, err) + + var locsWithoutDecl []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: lsp.DocumentURI(uri), + }, + Position: lsp.Position{Line: 0, Character: 5}, // on "go_library" + }, + Context: lsp.ReferenceContext{ + IncludeDeclaration: false, + }, + }, &locsWithoutDecl) + require.NoError(t, err) + + // With declaration should include one more location (the definition) + assert.GreaterOrEqual(t, len(locsWithDecl), len(locsWithoutDecl), + "IncludeDeclaration=true should return at least as many results") +} + +func TestFindLabelsForFile(t *testing.T) { + // This test verifies that findLabelsForFile returns ALL labels that produce a given file, + // not just the first one found. This is important when the same file is exposed by + // multiple filegroups under different labels (e.g., //pkg:alias1 and //pkg:alias2). + h := initHandler() + h.WaitForPackageTree() + + // Test with a file that exists - should return at least one label or empty + // (we can't easily test multiple labels without setting up complex test data, + // but we verify the function doesn't panic and returns a slice) + labels := h.findLabelsForFile("nonexistent_file.build_defs") + assert.Empty(t, labels, "nonexistent file should return empty labels") +} + +func TestBuiltinHasLabelsField(t *testing.T) { + // Verify that builtins are populated with Labels (slice) for subinclude tracking + h := initHandler() + h.WaitForPackageTree() + + // Check that builtins exist and have the Labels field properly initialized + h.mutex.Lock() + defer h.mutex.Unlock() + + // We should have some builtins loaded + assert.NotEmpty(t, h.builtins, "expected builtins to be populated") + + // Each builtin should have a Labels field (may be empty for core builtins) + for name, builtinList := range h.builtins { + for _, b := range builtinList { + // Labels field should exist (not nil) - it's a slice so empty is valid + assert.NotNil(t, b.Stmt, "builtin %s should have a statement", name) + // Labels is a slice, so we just verify it's accessible (not nil check needed for slices) + _ = b.Labels // This would panic if the field didn't exist + } + } +} + +// initHandlerWithRoot initializes a handler with a custom root directory +func initHandlerWithRoot(root string) *Handler { + h := NewHandler() + h.Conn = &rpc{ + Notifications: make(chan message, 100), + } + result := &lsp.InitializeResult{} + if err := h.Request("initialize", &lsp.InitializeParams{ + Capabilities: lsp.ClientCapabilities{}, + RootURI: lsp.DocumentURI("file://" + root), + }, result); err != nil { + log.Fatalf("init failed: %s", err) + } + return h +} + +func TestReferencesMultipleLabelsOneFile(t *testing.T) { + // Test that find-references works when a file is exposed via multiple labels. + // Scenario: + // - shared_defs/my_func.build_defs defines my_shared_func() + // - shared_defs:alias1 and shared_defs:alias2 both expose my_func.build_defs + // - pkg_a subincludes alias1 and calls my_shared_func() + // - pkg_b subincludes alias2 and calls my_shared_func() + // Expected: find-references on my_shared_func should find both pkg_a and pkg_b usages + + testDataRoot := filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data_find_references") + h := initHandlerWithRoot(testDataRoot) + + // Open the shared_defs file where my_shared_func is defined + defsFile := filepath.Join(testDataRoot, "shared_defs/my_func.build_defs") + uri := lsp.DocumentURI("file://" + defsFile) + content, err := os.ReadFile(defsFile) + require.NoError(t, err) + + err = h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: uri, + Text: string(content), + }, + }, nil) + require.NoError(t, err) + h.WaitForPackageTree() + + // Request references for my_shared_func (cursor on "def my_shared_func") + var locs []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: uri}, + Position: lsp.Position{Line: 0, Character: 6}, // on "my_shared_func" + }, + Context: lsp.ReferenceContext{IncludeDeclaration: false}, + }, &locs) + require.NoError(t, err) + + // Should find references in both pkg_a and pkg_b + // Even though they subinclude different labels (alias1 vs alias2) + var foundPkgA, foundPkgB bool + for _, loc := range locs { + uriStr := string(loc.URI) + if filepath.Base(filepath.Dir(uriStr)) == "pkg_a" { + foundPkgA = true + } + if filepath.Base(filepath.Dir(uriStr)) == "pkg_b" { + foundPkgB = true + } + } + + assert.True(t, foundPkgA, "expected to find reference in pkg_a (subincludes alias1)") + assert.True(t, foundPkgB, "expected to find reference in pkg_b (subincludes alias2)") +} + +func TestReferencesNoFalsePositivesForSameNameDifferentFile(t *testing.T) { + // Test that we DON'T get false positives when two different files define + // functions with the same name. + // Scenario: + // - defs1/func.build_defs defines duplicate_func() + // - defs2/func.build_defs ALSO defines duplicate_func() (different file!) + // - pkg_uses_defs1 subincludes //defs1 and calls duplicate_func() + // - pkg_uses_defs2 subincludes //defs2 and calls duplicate_func() + // Expected: find-references on defs1's definition should ONLY find pkg_uses_defs1, + // NOT pkg_uses_defs2 (which uses a different definition) + + testDataRoot := filepath.Join(os.Getenv("TEST_DIR"), "tools/build_langserver/lsp/test_data_find_references") + h := initHandlerWithRoot(testDataRoot) + + // Open defs1's file where duplicate_func is defined + defs1File := filepath.Join(testDataRoot, "defs1/func.build_defs") + uri1 := lsp.DocumentURI("file://" + defs1File) + content1, err := os.ReadFile(defs1File) + require.NoError(t, err) + + err = h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: uri1, + Text: string(content1), + }, + }, nil) + require.NoError(t, err) + h.WaitForPackageTree() + + // Request references for duplicate_func from defs1 + var locs []lsp.Location + err = h.Request("textDocument/references", &lsp.ReferenceParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: uri1}, + Position: lsp.Position{Line: 0, Character: 6}, // on "duplicate_func" + }, + Context: lsp.ReferenceContext{IncludeDeclaration: false}, + }, &locs) + require.NoError(t, err) + + // Check which packages were found + var foundPkgUsesDefs1, foundPkgUsesDefs2 bool + for _, loc := range locs { + uriStr := string(loc.URI) + dir := filepath.Base(filepath.Dir(uriStr)) + if dir == "pkg_uses_defs1" { + foundPkgUsesDefs1 = true + } + if dir == "pkg_uses_defs2" { + foundPkgUsesDefs2 = true + } + } + + // Should find pkg_uses_defs1 (correctly subincludes defs1) + assert.True(t, foundPkgUsesDefs1, "expected to find reference in pkg_uses_defs1") + // Should NOT find pkg_uses_defs2 (uses different definition from defs2) + assert.False(t, foundPkgUsesDefs2, "should NOT find reference in pkg_uses_defs2 - it uses a different definition") +} + +// TestReferencesSurvivesBrokenParse verifies that find-references doesn't crash +// when a document fails to parse. +func TestReferencesSurvivesBrokenParse(t *testing.T) { + h := initHandler() + + // Open a doc with syntactically broken content + brokenContent := "def broken_func(\n # missing closing paren" + h.Request("textDocument/didOpen", &lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{ + URI: "file:///tmp/broken.build", + Text: brokenContent, + Version: 1, + }, + }, nil) + + // parseIfNeeded should handle the parse failure gracefully + doc := h.docs["/tmp/broken.build"] + if doc != nil { + ast := h.parseIfNeeded(doc) + // Should return empty/partial AST, not panic + t.Logf("parsed broken file, got %d statements", len(ast)) + } + + // Also verify find-references doesn't crash on a function that doesn't exist + locs, err := h.findFunctionReferences("broken_func", "", false) + assert.NoError(t, err) + t.Logf("found %d locations for non-existent func (expected: 0)", len(locs)) +} diff --git a/tools/build_langserver/lsp/test_data_find_references/.plzconfig b/tools/build_langserver/lsp/test_data_find_references/.plzconfig new file mode 100644 index 000000000..8b1807af4 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/.plzconfig @@ -0,0 +1,2 @@ +[parse] +buildfilename = test.build diff --git a/tools/build_langserver/lsp/test_data_find_references/defs1/func.build_defs b/tools/build_langserver/lsp/test_data_find_references/defs1/func.build_defs new file mode 100644 index 000000000..04b53ab16 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/defs1/func.build_defs @@ -0,0 +1,3 @@ +def duplicate_func(name:str): + """First definition of duplicate_func - from defs1.""" + return filegroup(name = name) diff --git a/tools/build_langserver/lsp/test_data_find_references/defs1/test.build b/tools/build_langserver/lsp/test_data_find_references/defs1/test.build new file mode 100644 index 000000000..75490c41f --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/defs1/test.build @@ -0,0 +1,5 @@ +filegroup( + name = "defs1", + srcs = ["func.build_defs"], + visibility = ["PUBLIC"], +) diff --git a/tools/build_langserver/lsp/test_data_find_references/defs2/func.build_defs b/tools/build_langserver/lsp/test_data_find_references/defs2/func.build_defs new file mode 100644 index 000000000..09c539d39 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/defs2/func.build_defs @@ -0,0 +1,3 @@ +def duplicate_func(name:str): + """Second definition of duplicate_func - from defs2 (different file!).""" + return filegroup(name = name) diff --git a/tools/build_langserver/lsp/test_data_find_references/defs2/test.build b/tools/build_langserver/lsp/test_data_find_references/defs2/test.build new file mode 100644 index 000000000..082ac5d12 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/defs2/test.build @@ -0,0 +1,5 @@ +filegroup( + name = "defs2", + srcs = ["func.build_defs"], + visibility = ["PUBLIC"], +) diff --git a/tools/build_langserver/lsp/test_data_find_references/pkg_a/test.build b/tools/build_langserver/lsp/test_data_find_references/pkg_a/test.build new file mode 100644 index 000000000..a822a8449 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/pkg_a/test.build @@ -0,0 +1,6 @@ +subinclude("//shared_defs:alias1") + +my_shared_func( + name = "target_a", + srcs = ["a.txt"], +) diff --git a/tools/build_langserver/lsp/test_data_find_references/pkg_b/test.build b/tools/build_langserver/lsp/test_data_find_references/pkg_b/test.build new file mode 100644 index 000000000..4126933f5 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/pkg_b/test.build @@ -0,0 +1,6 @@ +subinclude("//shared_defs:alias2") + +my_shared_func( + name = "target_b", + srcs = ["b.txt"], +) diff --git a/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs1/test.build b/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs1/test.build new file mode 100644 index 000000000..7fa30ff76 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs1/test.build @@ -0,0 +1,3 @@ +subinclude("//defs1") + +duplicate_func(name = "from_defs1") diff --git a/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs2/test.build b/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs2/test.build new file mode 100644 index 000000000..41793cd48 --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/pkg_uses_defs2/test.build @@ -0,0 +1,3 @@ +subinclude("//defs2") + +duplicate_func(name = "from_defs2") diff --git a/tools/build_langserver/lsp/test_data_find_references/shared_defs/my_func.build_defs b/tools/build_langserver/lsp/test_data_find_references/shared_defs/my_func.build_defs new file mode 100644 index 000000000..9b4c890fb --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/shared_defs/my_func.build_defs @@ -0,0 +1,6 @@ +def my_shared_func(name:str, srcs:list=[]): + """A function exposed via multiple labels for testing find-references.""" + return filegroup( + name = name, + srcs = srcs, + ) diff --git a/tools/build_langserver/lsp/test_data_find_references/shared_defs/test.build b/tools/build_langserver/lsp/test_data_find_references/shared_defs/test.build new file mode 100644 index 000000000..48ad8062c --- /dev/null +++ b/tools/build_langserver/lsp/test_data_find_references/shared_defs/test.build @@ -0,0 +1,12 @@ +# Two filegroups exposing the same file under different labels +filegroup( + name = "alias1", + srcs = ["my_func.build_defs"], + visibility = ["PUBLIC"], +) + +filegroup( + name = "alias2", + srcs = ["my_func.build_defs"], + visibility = ["PUBLIC"], +)