Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/cmap/cerrmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,14 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) {
}
return v.Val, v.Err
}

// Range calls f for each key-value pair in the map.
// No particular consistency guarantees are made during iteration.
func (m *ErrMap[K, V]) Range(f func(key K, val V)) {
m.m.Range(func(key K, val errV[V]) {
if val.Err != nil {
return // skip errors
}
f(key, val.Val)
})
}
19 changes: 19 additions & 0 deletions src/cmap/cmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ func (m *Map[K, V]) Values() []V {
return ret
}

// Range calls f for each key-value pair in the map.
// No particular consistency guarantees are made during iteration.
func (m *Map[K, V]) Range(f func(key K, val V)) {
for i := 0; i < len(m.shards); i++ {
m.shards[i].Range(f)
}
}

// An awaitableValue represents a value in the map & an awaitable channel for it to exist.
type awaitableValue[V any] struct {
Val V
Expand Down Expand Up @@ -195,3 +203,14 @@ func (s *shard[K, V]) Contains(key K) bool {
_, ok := s.m[key]
return ok
}

// Range calls f for each key-value pair in this shard.
func (s *shard[K, V]) Range(f func(key K, val V)) {
s.l.RLock()
defer s.l.RUnlock()
for k, v := range s.m {
if v.Wait == nil { // Only include completed values
f(k, v.Val)
}
}
}
18 changes: 18 additions & 0 deletions src/parse/asp/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) {
}
}

// AllFunctionsByFile returns all function definitions grouped by filename.
// This includes functions from builtins, plugins, and subincludes.
// It iterates over the ASTs stored by the interpreter.
func (p *Parser) AllFunctionsByFile() map[string][]*Statement {
if p.interpreter == nil || p.interpreter.asts == nil {
return nil
}
result := make(map[string][]*Statement)
p.interpreter.asts.Range(func(filename string, stmts []*Statement) {
for _, stmt := range stmts {
if stmt.FuncDef != nil {
result[filename] = append(result[filename], stmt)
}
}
})
return result
}

// whitelistedKwargs returns true if the given built-in function name is allowed to
// be called as non-kwargs.
// TODO(peterebden): Come up with a syntax that exposes this directly in the file.
Expand Down
13 changes: 13 additions & 0 deletions src/parse/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState {
return state
}

// GetAspParser returns the underlying asp.Parser from the state's parser.
// This is useful for tools like the language server that need direct access to AST information.
// Returns nil if the state's parser is not set or is not an aspParser.
func GetAspParser(state *core.BuildState) *asp.Parser {
if state.Parser == nil {
return nil
}
if ap, ok := state.Parser.(*aspParser); ok {
return ap.parser
}
return nil
}

// aspParser implements the core.Parser interface around our parser package.
type aspParser struct {
parser *asp.Parser
Expand Down
10 changes: 9 additions & 1 deletion tools/build_langserver/lsp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
"definition.go",
"diagnostics.go",
"lsp.go",
"references.go",
"symbols.go",
"text.go",
],
Expand All @@ -17,8 +18,10 @@ go_library(
"//rules",
"//src/core",
"//src/fs",
"//src/parse",
"//src/parse/asp",
"//src/plz",
"//src/query",
"//tools/build_langserver/lsp/astutils",
],
)
Expand All @@ -29,15 +32,20 @@ go_test(
srcs = [
"definition_test.go",
"lsp_test.go",
"references_test.go",
"symbols_test.go",
],
data = ["test_data"],
data = [
"test_data",
"test_data_find_references",
],
deps = [
":lsp",
"///third_party/go/github.com_please-build_buildtools//build",
"///third_party/go/github.com_sourcegraph_go-lsp//:go-lsp",
"///third_party/go/github.com_sourcegraph_jsonrpc2//:jsonrpc2",
"///third_party/go/github.com_stretchr_testify//assert",
"///third_party/go/github.com_stretchr_testify//require",
"//src/cli",
"//src/core",
],
Expand Down
6 changes: 3 additions & 3 deletions tools/build_langserver/lsp/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ func (h *Handler) completeString(doc *doc, s string, line, col int) (*lsp.Comple
// completeIdent completes an arbitrary identifier
func (h *Handler) completeIdent(doc *doc, s string, line, col int) (*lsp.CompletionList, error) {
list := &lsp.CompletionList{}
for name, f := range h.builtins {
if strings.HasPrefix(name, s) {
for name, builtins := range h.builtins {
if strings.HasPrefix(name, s) && len(builtins) > 0 {
item := completionItem(name, "", line, col)
item.Documentation = f.Stmt.FuncDef.Docstring
item.Documentation = builtins[0].Stmt.FuncDef.Docstring
item.Kind = lsp.CIKFunction
list.Items = append(list.Items, item)
}
Expand Down
35 changes: 23 additions & 12 deletions tools/build_langserver/lsp/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,40 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca
ast := h.parseIfNeeded(doc)
f := doc.AspFile()

var locs []lsp.Location
locs := []lsp.Location{}
pos := aspPos(params.Position)
asp.WalkAST(ast, func(expr *asp.Expression) bool {
if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) {
exprStart := f.Pos(expr.Pos)
exprEnd := f.Pos(expr.EndPos)
if !asp.WithinRange(pos, exprStart, exprEnd) {
return false
}

if expr.Val.Ident != nil {
if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" {
locs = append(locs, loc)
}
return false
}

if expr.Val.String != "" {
label := astutils.TrimStrLit(expr.Val.String)
if loc := h.findLabel(doc.PkgName, label); loc.URI != "" {
locs = append(locs, loc)
}
return false
}

return true
})
// It might also be a statement.
// It might also be a statement (e.g. a function call like go_library(...))
asp.WalkAST(ast, func(stmt *asp.Statement) bool {
if stmt.Ident != nil {
endPos := f.Pos(stmt.Pos)
stmtStart := f.Pos(stmt.Pos)
endPos := stmtStart
// TODO(jpoole): The AST should probably just have this information
endPos.Column += len(stmt.Ident.Name)

if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) {
return false
if !asp.WithinRange(pos, stmtStart, endPos) {
return true // continue to other statements
}

if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" {
locs = append(locs, loc)
}
Expand All @@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location {
}

pkg := h.state.Graph.PackageByLabel(l)
if pkg == nil {
return lsp.Location{}
}
uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename))
loc := lsp.Location{URI: uri}
doc, err := h.maybeOpenDoc(uri)
Expand Down Expand Up @@ -137,9 +139,18 @@ func findName(args []asp.CallArgument) string {

// findGlobal returns the location of a global of the given name.
func (h *Handler) findGlobal(name string) lsp.Location {
if f, present := h.builtins[name]; present {
h.mutex.Lock()
builtins := h.builtins[name]
h.mutex.Unlock()
if len(builtins) > 0 {
f := builtins[0]
filename := f.Pos.Filename
// Make path absolute if it's relative
if !filepath.IsAbs(filename) {
filename = filepath.Join(h.root, filename)
}
return lsp.Location{
URI: lsp.DocumentURI("file://" + f.Pos.Filename),
URI: lsp.DocumentURI("file://" + filename),
Range: rng(f.Pos, f.EndPos),
}
}
Expand Down
97 changes: 91 additions & 6 deletions tools/build_langserver/lsp/lsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/thought-machine/please/rules"
"github.com/thought-machine/please/src/core"
"github.com/thought-machine/please/src/fs"
"github.com/thought-machine/please/src/parse"
"github.com/thought-machine/please/src/parse/asp"
"github.com/thought-machine/please/src/plz"
)
Expand All @@ -33,14 +34,15 @@ type Handler struct {
mutex sync.Mutex // guards docs
state *core.BuildState
parser *asp.Parser
builtins map[string]builtin
builtins map[string][]builtin
pkgs *pkg
root string
}

type builtin struct {
Stmt *asp.Statement
Pos, EndPos asp.FilePosition
Labels []core.BuildLabel // which subinclude targets can provide this (empty for core builtins)
}

// A Conn is a minimal set of the jsonrpc2.Conn that we need.
Expand All @@ -55,7 +57,7 @@ func NewHandler() *Handler {
return &Handler{
docs: map[string]*doc{},
pkgs: &pkg{},
builtins: map[string]builtin{},
builtins: map[string][]builtin{},
}
}

Expand Down Expand Up @@ -173,6 +175,12 @@ func (h *Handler) handle(method string, params *json.RawMessage) (res interface{
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams}
}
return h.definition(positionParams)
case "textDocument/references":
referenceParams := &lsp.ReferenceParams{}
if err := json.Unmarshal(*params, referenceParams); err != nil {
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams}
}
return h.references(referenceParams)
default:
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound}
}
Expand All @@ -195,13 +203,35 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul
}
h.state = core.NewBuildState(config)
h.state.NeedBuild = false
// We need an unwrapped parser instance as well for raw access.
h.parser = asp.NewParser(h.state)
// Initialize the parser on state first, so that plz.RunHost uses the same parser.
// This ensures plugin subincludes are stored in the same AST cache we use.
parse.InitParser(h.state)
h.parser = parse.GetAspParser(h.state)
if h.parser == nil {
return nil, fmt.Errorf("failed to get asp parser from state")
}
// Parse everything in the repo up front.
// This is a lot easier than trying to do clever partial parses later on, although
// eventually we may want that if we start dealing with truly large repos.
go func() {
// Start a goroutine to periodically load parser functions as they become available.
// This allows go-to-definition to work progressively while the full parse runs.
done := make(chan struct{})
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-done:
h.loadParserFunctions()
return
case <-ticker.C:
h.loadParserFunctions()
}
}
}()
plz.RunHost(core.WholeGraph, h.state)
close(done)
log.Debug("initial parse complete")
h.buildPackageTree()
log.Debug("built completion package tree")
Expand All @@ -221,6 +251,7 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul
DocumentFormattingProvider: true,
DocumentSymbolProvider: true,
DefinitionProvider: true,
ReferencesProvider: true,
CompletionProvider: &lsp.CompletionOptions{
TriggerCharacters: []string{"/", ":"},
},
Expand Down Expand Up @@ -256,18 +287,72 @@ func (h *Handler) loadBuiltins() error {
f := asp.NewFile(dest, data)
for _, stmt := range stmts {
if stmt.FuncDef != nil {
h.builtins[stmt.FuncDef.Name] = builtin{
h.builtins[stmt.FuncDef.Name] = append(h.builtins[stmt.FuncDef.Name], builtin{
Stmt: stmt,
Pos: f.Pos(stmt.Pos),
EndPos: f.Pos(stmt.EndPos),
}
})
}
}
}
log.Debug("loaded builtin function information")
return nil
}

// loadParserFunctions loads function definitions from the parser's ASTs.
// This includes plugin-defined functions like go_library, python_library, etc.
func (h *Handler) loadParserFunctions() {
funcsByFile := h.parser.AllFunctionsByFile()
if funcsByFile == nil {
return
}
h.mutex.Lock()
defer h.mutex.Unlock()
for filename, stmts := range funcsByFile {
// Read the file to create a File object for position conversion
data, err := os.ReadFile(filename)
if err != nil {
log.Warning("failed to read file %s: %v", filename, err)
continue
}
file := asp.NewFile(filename, data)
labels := h.findLabelsForFile(filename)
for _, stmt := range stmts {
name := stmt.FuncDef.Name
h.builtins[name] = append(h.builtins[name], builtin{
Stmt: stmt,
Pos: file.Pos(stmt.Pos),
EndPos: file.Pos(stmt.EndPos),
Labels: labels,
})
}
}
}

// findLabelsForFile finds all build labels that produce the given output file.
// A file can be exposed by multiple targets (e.g., multiple filegroups pointing to the same source).
func (h *Handler) findLabelsForFile(filename string) []core.BuildLabel {
// Make the filename absolute for comparison
if !filepath.IsAbs(filename) {
filename = filepath.Join(h.root, filename)
}
var labels []core.BuildLabel
for _, pkg := range h.state.Graph.PackageMap() {
for _, target := range pkg.AllTargets() {
for _, out := range target.FullOutputs() {
absOut := out
if !filepath.IsAbs(out) {
absOut = filepath.Join(h.root, out)
}
if absOut == filename {
labels = append(labels, target.Label)
}
}
}
}
return labels
}

// fromURI converts a DocumentURI to a path.
func fromURI(uri lsp.DocumentURI) string {
if !strings.HasPrefix(string(uri), "file://") {
Expand Down
Loading
Loading