From 51a10bcf62151cb25e0c4766e32b170b98e6580d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 11:13:17 -0400 Subject: [PATCH 001/196] jsonschema: resolve URIs for $id Assign a base URI to every schema. See the comment on resolveURIs for details. Change-Id: I2332d825c486766832503cabdd0bf9d819171d53 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669997 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- internal/mcp/jsonschema/resolve.go | 112 ++++++++++++++++++++++- internal/mcp/jsonschema/resolve_test.go | 58 +++++++++++- internal/mcp/jsonschema/schema.go | 17 +++- internal/mcp/jsonschema/validate_test.go | 2 +- 4 files changed, 181 insertions(+), 8 deletions(-) diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 2a0f1abe391..d6199870d0c 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -10,6 +10,7 @@ package jsonschema import ( "errors" "fmt" + "net/url" "regexp" ) @@ -20,20 +21,44 @@ import ( // Call [Schema.Resolve] to obtain a Resolved from a Schema. type Resolved struct { root *Schema + // map from $ids to their schemas + resolvedURIs map[string]*Schema } // Resolve resolves all references within the schema and performs other tasks that // prepare the schema for validation. -func (root *Schema) Resolve() (*Resolved, error) { +// baseURI can be empty, or an absolute URI (one that starts with a scheme). +// It is resolved (in the URI sense; see [url.ResolveReference]) with root's $id property. +// If the resulting URI is not absolute, then the schema cannot not contain relative URI references. +func (root *Schema) Resolve(baseURI string) (*Resolved, error) { // There are three steps involved in preparing a schema to validate. // 1. Check: validate the schema against a meta-schema, and perform other well-formedness // checks. Precompute some values along the way. - // 2. Resolve URIs: TODO. + // 2. 
Resolve URIs: determine the base URI of the root and all its subschemas, and + // resolve (in the URI sense) all identifiers and anchors with their bases. This step results + // in a map from URIs to schemas within root. // 3. Resolve references: TODO. if err := root.check(); err != nil { return nil, err } - return &Resolved{root: root}, nil + var base *url.URL + if baseURI == "" { + base = &url.URL{} // so we can call ResolveReference on it + } else { + var err error + base, err = url.Parse(baseURI) + if err != nil { + return nil, fmt.Errorf("parsing base URI: %w", err) + } + } + m, err := resolveURIs(root, base) + if err != nil { + return nil, err + } + return &Resolved{ + root: root, + resolvedURIs: m, + }, nil } func (s *Schema) check() error { @@ -91,3 +116,84 @@ func (s *Schema) checkLocal(report func(error)) { } } } + +// resolveURIs resolves the ids and anchors in all the schemas of root, relative +// to baseURI. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section +// 8.2.1. + +// TODO(jba): anchors (section 8.2.2) +// TODO(jba): dynamicAnchors (ditto) +// +// Every schema has a base URI and a parent base URI. +// +// The parent base URI is the base URI of the lexically enclosing schema, or for +// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. +// +// If the schema has no $id property, the base URI of a schema is that of its parent. +// If the schema does have an $id, it must be a URI, possibly relative. The schema's +// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against +// the parent base. +// +// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): +// +// { +// allOf: [ +// {$id: "sub1.json", minLength: 5}, +// {$id: "http://b.com", minimum: 10}, +// {not: {maximum: 20}} +// ] +// } +// +// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. 
+// +// schema base URI +// root http://a.com/root.json +// allOf/0 http://a.com/sub1.json +// allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) +// allOf/2 http://a.com/root.json (inherited from parent) +// allOf/2/not http://a.com/root.json (inherited from parent) +func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { + resolvedURIs := map[string]*Schema{} + + var resolve func(s, base *Schema) error + resolve = func(s, base *Schema) error { + // ids are scoped to the root. + if s.ID == "" { + // If a schema doesn't have an $id, its base is the parent base. + s.baseURI = base.baseURI + } else { + // A non-empty ID establishes a new base. + idURI, err := url.Parse(s.ID) + if err != nil { + return err + } + if idURI.Fragment != "" { + return fmt.Errorf("$id %s must not have a fragment", s.ID) + } + // The base URI for this schema is its $id resolved against the parent base. + s.baseURI = base.baseURI.ResolveReference(idURI) + if !s.baseURI.IsAbs() { + return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %s)", s.ID, s.baseURI) + } + resolvedURIs[s.baseURI.String()] = s + base = s // needed for anchors + } + + for c := range s.children() { + if err := resolve(c, base); err != nil { + return err + } + } + return nil + } + + // Set the root URI to the base for now. If the root has an $id, the base will change. + root.baseURI = baseURI + // The original base, even if changed, is still a valid way to refer to the root. 
+ resolvedURIs[baseURI.String()] = root + if err := resolve(root, root); err != nil { + return nil, err + } + return resolvedURIs, nil +} diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 7e2929438f4..474b993fe47 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -5,7 +5,10 @@ package jsonschema import ( + "maps" + "net/url" "regexp" + "slices" "testing" ) @@ -24,7 +27,7 @@ func TestCheckLocal(t *testing.T) { "regexp", }, } { - _, err := tt.s.Resolve() + _, err := tt.s.Resolve("") if err == nil { t.Errorf("%s: unexpectedly passed", tt.s.json()) continue @@ -35,3 +38,56 @@ func TestCheckLocal(t *testing.T) { } } } + +func TestResolveURIs(t *testing.T) { + for _, baseURI := range []string{"", "http://a.com"} { + t.Run(baseURI, func(t *testing.T) { + root := &Schema{ + ID: "http://b.com", + Items: &Schema{ + ID: "/foo.json", + }, + Contains: &Schema{ + ID: "/bar.json", + Anchor: "a", + Items: &Schema{ + Anchor: "b", + Items: &Schema{ + // An ID shouldn't be a query param, but this tests + // resolving an ID with its parent. 
+ ID: "?items", + Anchor: "c", + }, + }, + }, + } + base, err := url.Parse(baseURI) + if err != nil { + t.Fatal(err) + } + got, err := resolveURIs(root, base) + if err != nil { + t.Fatal(err) + } + + want := map[string]*Schema{ + baseURI: root, + "http://b.com/foo.json": root.Items, + "http://b.com/bar.json": root.Contains, + "http://b.com/bar.json?items": root.Contains.Items.Items, + } + if baseURI != root.ID { + want[root.ID] = root + } + + gotKeys := slices.Sorted(maps.Keys(got)) + wantKeys := slices.Sorted(maps.Keys(want)) + if !slices.Equal(gotKeys, wantKeys) { + t.Errorf("ID keys:\ngot %q\nwant %q", gotKeys, wantKeys) + } + if !maps.Equal(got, want) { + t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, want) + } + }) + } +} diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 15960b47ca8..4105607b5ba 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -13,6 +13,7 @@ import ( "fmt" "iter" "math" + "net/url" "regexp" ) @@ -109,6 +110,11 @@ type Schema struct { DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` // computed fields + // If the schema doesn't have an ID, the base URI is that of its parent. + // Otherwise, the base URI is the ID resolved against the parent's baseURI. + // The parent base URI at top level is where the schema was loaded from, or + // if not loaded, then it should be provided to Schema.Resolve. + baseURI *url.URL pattern *regexp.Regexp patternProperties map[*regexp.Regexp]*Schema } @@ -283,11 +289,11 @@ func (s *Schema) every(f func(*Schema) bool) bool { f(s) && s.everyChild(f) } -// everyChild returns an iterator over the immediate child schemas of s. +// everyChild reports whether f is true for every immediate child schema of s. 
// -// It does not yield nils from fields holding individual schemas, like Contains, +// It does not call f on nil-valued fields holding individual schemas, like Contains, // because a nil value indicates that the field is absent. -// It does yield nils when they occur in slices and maps, so those invalid values +// It does call f on nils when they occur in slices and maps, so those invalid values // can be detected when the schema is validated. func (s *Schema) everyChild(f func(*Schema) bool) bool { // Fields that contain individual schemas. A nil is valid: it just means the field isn't present. @@ -324,3 +330,8 @@ func (s *Schema) everyChild(f func(*Schema) bool) bool { func (s *Schema) all() iter.Seq[*Schema] { return func(yield func(*Schema) bool) { s.every(yield) } } + +// children wraps everyChild in an iterator. +func (s *Schema) children() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.everyChild(yield) } +} diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 10e757fd8e6..7abf5f0fd27 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -51,7 +51,7 @@ func TestValidate(t *testing.T) { } for _, g := range groups { t.Run(g.Description, func(t *testing.T) { - rs, err := g.Schema.Resolve() + rs, err := g.Schema.Resolve("") if err != nil { t.Fatal(err) } From 9635d6cadabb59f04365f184e8df6c13ca0e9fe8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 11:24:37 -0400 Subject: [PATCH 002/196] jsonschema: resolve anchors Change-Id: Ibe7d8df0a46d3e3a8b5d0651b1113a54e756c7df Reviewed-on: https://go-review.googlesource.com/c/tools/+/669998 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/resolve.go | 15 +++++++++++++-- internal/mcp/jsonschema/resolve_test.go | 23 ++++++++++++++++++----- internal/mcp/jsonschema/schema.go | 4 +++- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git 
a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index d6199870d0c..ae3da3737c3 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -122,8 +122,7 @@ func (s *Schema) checkLocal(report func(error)) { // See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section // 8.2.1. -// TODO(jba): anchors (section 8.2.2) -// TODO(jba): dynamicAnchors (ditto) +// TODO(jba): dynamicAnchors (§8.2.2) // // Every schema has a base URI and a parent base URI. // @@ -180,6 +179,18 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { base = s // needed for anchors } + // Anchors are URI fragments that are scoped to their base. + // We treat them as keys in a map stored within the schema. + if s.Anchor != "" { + if base.anchors[s.Anchor] != nil { + return fmt.Errorf("duplicate anchor %q in %s", s.Anchor, base.baseURI) + } + if base.anchors == nil { + base.anchors = map[string]*Schema{} + } + base.anchors[s.Anchor] = s + } + for c := range s.children() { if err := resolve(c, base); err != nil { return err diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 474b993fe47..f8fb2b5dfb1 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -70,23 +70,36 @@ func TestResolveURIs(t *testing.T) { t.Fatal(err) } - want := map[string]*Schema{ + wantIDs := map[string]*Schema{ baseURI: root, "http://b.com/foo.json": root.Items, "http://b.com/bar.json": root.Contains, "http://b.com/bar.json?items": root.Contains.Items.Items, } if baseURI != root.ID { - want[root.ID] = root + wantIDs[root.ID] = root + } + wantAnchors := map[*Schema]map[string]*Schema{ + root.Contains: {"a": root.Contains, "b": root.Contains.Items}, + root.Contains.Items.Items: {"c": root.Contains.Items.Items}, } gotKeys := slices.Sorted(maps.Keys(got)) - wantKeys := slices.Sorted(maps.Keys(want)) + wantKeys := 
slices.Sorted(maps.Keys(wantIDs)) if !slices.Equal(gotKeys, wantKeys) { t.Errorf("ID keys:\ngot %q\nwant %q", gotKeys, wantKeys) } - if !maps.Equal(got, want) { - t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, want) + if !maps.Equal(got, wantIDs) { + t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, wantIDs) + } + for s := range root.all() { + if want := wantAnchors[s]; want != nil { + if got := s.anchors; !maps.Equal(got, want) { + t.Errorf("anchors:\ngot %+v\n\nwant %+v", got, want) + } + } else if s.anchors != nil { + t.Errorf("non-nil anchors for %s", s) + } } }) } diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 4105607b5ba..1b658045cc9 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -114,7 +114,9 @@ type Schema struct { // Otherwise, the base URI is the ID resolved against the parent's baseURI. // The parent base URI at top level is where the schema was loaded from, or // if not loaded, then it should be provided to Schema.Resolve. - baseURI *url.URL + baseURI *url.URL + // map from anchors to subschemas + anchors map[string]*Schema pattern *regexp.Regexp patternProperties map[*regexp.Regexp]*Schema } From 898dcae1e5247e0376166df3541732b2f3834cd6 Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Wed, 23 Apr 2025 12:13:36 -0400 Subject: [PATCH 003/196] gopls/internal/golang/completion: new code for unimported completions The unimported completion code is invoked when the user is looking for a package-level symbol in a package that has not been imported into the current file. The code looks for matching symbols in a number of places. If the user has typed foo.lw (with the cursor just after the 'w'), then a matching symbol will be exported from some package 'foo', with 'l', 'w', a subsequence of its name (converted to lower case). That is, the typed symbol name to the left of the cursor is treated as a pattern. The code looks first for a match in the other files of the current package. 
Failing that, it looks in the standard library, using the stdlib.PackageSymbols built into gopls. Failing that, it looks in the rest of the workspace, and failing that, it looks in the module cache. Tests have been changed to use the new code. (An alternative would be to duplicate the relevant tests, one version using the old code and the other using the new code.) Change-Id: Ie04e5571ce09f637fc3913f8e2c51aa65704844f Reviewed-on: https://go-review.googlesource.com/c/tools/+/667576 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/cache/imports.go | 5 +- gopls/internal/cache/source.go | 3 +- gopls/internal/cache/view.go | 8 +- .../internal/golang/completion/completion.go | 17 + .../internal/golang/completion/unimported.go | 371 ++++++++++++++++++ .../integration/completion/completion_test.go | 17 +- .../marker/testdata/completion/issue59096.txt | 5 + .../marker/testdata/completion/issue60545.txt | 5 + .../marker/testdata/completion/randv2.txt | 25 ++ .../marker/testdata/completion/unimported.txt | 7 +- internal/imports/fix.go | 9 +- 11 files changed, 459 insertions(+), 13 deletions(-) create mode 100644 gopls/internal/golang/completion/unimported.go create mode 100644 gopls/internal/test/marker/testdata/completion/randv2.txt diff --git a/gopls/internal/cache/imports.go b/gopls/internal/cache/imports.go index 31a1b9d42a5..735801f2345 100644 --- a/gopls/internal/cache/imports.go +++ b/gopls/internal/cache/imports.go @@ -168,7 +168,10 @@ func newModcacheState(dir string) *modcacheState { return s } -func (s *modcacheState) GetIndex() (*modindex.Index, error) { +// getIndex reads the module cache index. It might not exist yet +// inside tests. It might contain no Entries if the cache +// is empty.
+func (s *modcacheState) getIndex() (*modindex.Index, error) { s.mu.Lock() defer s.mu.Unlock() ix := s.index diff --git a/gopls/internal/cache/source.go b/gopls/internal/cache/source.go index 047cc3971d8..8e223371291 100644 --- a/gopls/internal/cache/source.go +++ b/gopls/internal/cache/source.go @@ -134,8 +134,7 @@ func (s *goplsSource) ResolveReferences(ctx context.Context, filename string, mi } func (s *goplsSource) resolveCacheReferences(missing imports.References) ([]*result, error) { - state := s.S.view.modcacheState - ix, err := state.GetIndex() + ix, err := s.S.view.ModcacheIndex() if err != nil { event.Error(s.ctx, "resolveCacheReferences", err) } diff --git a/gopls/internal/cache/view.go b/gopls/internal/cache/view.go index 4e8375a77db..9c85e6a8c71 100644 --- a/gopls/internal/cache/view.go +++ b/gopls/internal/cache/view.go @@ -36,6 +36,7 @@ import ( "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/imports" + "golang.org/x/tools/internal/modindex" "golang.org/x/tools/internal/xcontext" ) @@ -372,6 +373,11 @@ func (v *View) Env() []string { ) } +// ModcacheIndex returns the module cache index +func (v *View) ModcacheIndex() (*modindex.Index, error) { + return v.modcacheState.getIndex() +} + // UpdateFolders updates the set of views for the new folders. // // Calling this causes each view to be reinitialized. @@ -1231,7 +1237,7 @@ func globsMatchPath(globs, target string) bool { n := strings.Count(glob, "/") prefix := target // Walk target, counting slashes, truncating at the N+1'th slash. 
- for i := 0; i < len(target); i++ { + for i := range len(target) { if target[i] == '/' { if n == 0 { prefix = target[:i] diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 83be9f2ed80..47fcbf463eb 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -1282,6 +1282,23 @@ func (c *completer) selector(ctx context.Context, sel *ast.SelectorExpr) error { // -- completion of symbols in unimported packages -- + // use new code for unimported completions, if flag allows it + if id, ok := sel.X.(*ast.Ident); ok && c.snapshot.Options().ImportsSource == settings.ImportsSourceGopls { + // The user might have typed strings.TLower, so id.Name==strings, sel.Sel.Name == TLower, + // but the cursor might be inside TLower, so adjust the prefix + prefix := sel.Sel.Name + if c.surrounding != nil { + if c.surrounding.content != sel.Sel.Name { + bug.Reportf("unexpected surrounding: %q != %q", c.surrounding.content, sel.Sel.Name) + } else { + prefix = sel.Sel.Name[:c.surrounding.cursor-c.surrounding.start] + } + } + c.unimported(ctx, metadata.PackageName(id.Name), prefix) + return nil + + } + // The deep completion algorithm is exceedingly complex and // deeply coupled to the now obsolete notions that all // token.Pos values can be interpreted by as a single FileSet diff --git a/gopls/internal/golang/completion/unimported.go b/gopls/internal/golang/completion/unimported.go new file mode 100644 index 00000000000..87c059697f3 --- /dev/null +++ b/gopls/internal/golang/completion/unimported.go @@ -0,0 +1,371 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package completion + +// unimported completion is invoked when the user types something like 'foo.xx', +// foo is known to be a package name not yet imported in the current file, and +// xx (or whatever the user has typed) is interpreted as a hint (pattern) for the +// member of foo that the user is looking for. +// +// This code looks for a suitable completion in a number of places. A 'suitable +// completion' is an exported symbol (so a type, const, var, or func) from package +// foo, which, after converting everything to lower case, has the pattern as a +// subsequence. +// +// The code looks for a suitable completion in +// 1. the imports of some other file of the current package, +// 2. the standard library, +// 3. the imports of some other file in the current workspace, +// 4. the module cache. +// It stops at the first success. + +import ( + "context" + "fmt" + "go/ast" + "go/printer" + "go/token" + "path" + "slices" + "strings" + + "golang.org/x/tools/gopls/internal/cache/metadata" + "golang.org/x/tools/gopls/internal/golang/completion/snippet" + "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/util/bug" + "golang.org/x/tools/internal/imports" + "golang.org/x/tools/internal/modindex" + "golang.org/x/tools/internal/stdlib" + "golang.org/x/tools/internal/versions" +) + +func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName, prefix string) error { + wsIDs, ourIDs := c.findPackageIDs(pkgname) + stdpkgs := c.stdlibPkgs(pkgname) + if len(ourIDs) > 0 { + // use the one in the current package, if possible + items := c.pkgIDmatches(ctx, ourIDs, pkgname, prefix) + if c.scoreList(items) { + return nil + } + } + // do the stdlib next. + // For now, use the workspace version of stdlib packages + // to get function snippets. CL 665335 will fix this. 
+ var x []metadata.PackageID + for _, mp := range stdpkgs { + if slices.Contains(wsIDs, metadata.PackageID(mp)) { + x = append(x, metadata.PackageID(mp)) + } + } + if len(x) > 0 { + items := c.pkgIDmatches(ctx, x, pkgname, prefix) + if c.scoreList(items) { + return nil + } + } + // just use the stdlib + items := c.stdlibMatches(stdpkgs, pkgname, prefix) + if c.scoreList(items) { + return nil + } + + // look in the rest of the workspace + items = c.pkgIDmatches(ctx, wsIDs, pkgname, prefix) + if c.scoreList(items) { + return nil + } + + // look in the module cache, for the last chance + items, err := c.modcacheMatches(pkgname, prefix) + if err == nil { + c.scoreList(items) + } + return nil +} + +// find all the packageIDs for packages in the workspace that have the desired name +// thisPkgIDs contains the ones known to the current package, wsIDs contains the others +func (c *completer) findPackageIDs(pkgname metadata.PackageName) (wsIDs, thisPkgIDs []metadata.PackageID) { + g := c.snapshot.MetadataGraph() + for pid, pkg := range c.snapshot.MetadataGraph().Packages { + if pkg.Name != pkgname { + continue + } + imports := g.ImportedBy[pid] + if slices.Contains(imports, c.pkg.Metadata().ID) { + thisPkgIDs = append(thisPkgIDs, pid) + } else { + wsIDs = append(wsIDs, pid) + } + } + return +} + +// find all the stdlib packages that have the desired name +func (c *completer) stdlibPkgs(pkgname metadata.PackageName) []metadata.PackagePath { + var pkgs []metadata.PackagePath // stlib packages that match pkg + for pkgpath := range stdlib.PackageSymbols { + v := metadata.PackageName(path.Base(pkgpath)) + if v == pkgname { + pkgs = append(pkgs, metadata.PackagePath(pkgpath)) + } else if imports.WithoutVersion(string(pkgpath)) == string(pkgname) { + pkgs = append(pkgs, metadata.PackagePath(pkgpath)) + } + } + return pkgs +} + +// return CompletionItems for all matching symbols in the packages in ids. 
+func (c *completer) pkgIDmatches(ctx context.Context, ids []metadata.PackageID, pkgname metadata.PackageName, prefix string) []CompletionItem { + pattern := strings.ToLower(prefix) + allpkgsyms, err := c.snapshot.Symbols(ctx, ids...) + if err != nil { + return nil // would if be worth retrying the ids one by one? + } + if len(allpkgsyms) != len(ids) { + bug.Errorf("Symbols returned %d values for %d pkgIDs", len(allpkgsyms), len(ids)) + return nil + } + var got []CompletionItem + for i, pkgID := range ids { + pkg := c.snapshot.MetadataGraph().Packages[pkgID] + if pkg == nil { + bug.Errorf("no metadata for %s", pkgID) + continue // something changed underfoot, otherwise can't happen + } + pkgsyms := allpkgsyms[i] + pkgfname := pkgsyms.Files[0].Path() + if !imports.CanUse(c.filename, pkgfname) { + // avoid unusable internal, etc + continue + } + // are any of these any good? + for np, asym := range pkgsyms.Symbols { + for _, sym := range asym { + if !token.IsExported(sym.Name) { + continue + } + if !usefulCompletion(sym.Name, pattern) { + // for json.U, the existing code finds InvalidUTF8Error + continue + } + var params []string + var kind protocol.CompletionItemKind + var detail string + switch sym.Kind { + case protocol.Function: + foundURI := pkgsyms.Files[np] + fh := c.snapshot.FindFile(foundURI) + pgf, err := c.snapshot.ParseGo(ctx, fh, 0) + if err == nil { + params = funcParams(pgf.File, sym.Name) + } + kind = protocol.FunctionCompletion + detail = fmt.Sprintf("func (from %q)", pkg.PkgPath) + case protocol.Variable: + kind = protocol.VariableCompletion + detail = fmt.Sprintf("var (from %q)", pkg.PkgPath) + case protocol.Constant: + kind = protocol.ConstantCompletion + detail = fmt.Sprintf("const (from %q)", pkg.PkgPath) + default: + continue + } + got = c.appendNewItem(got, sym.Name, + detail, + pkg.PkgPath, + kind, + pkgname, params) + } + } + } + return got +} + +// return CompletionItems for all the matches in packages in pkgs. 
+func (c *completer) stdlibMatches(pkgs []metadata.PackagePath, pkg metadata.PackageName, prefix string) []CompletionItem { + // check for deprecated symbols someday + got := make([]CompletionItem, 0) + pattern := strings.ToLower(prefix) + // avoid non-determinacy, especially for marker tests + slices.Sort(pkgs) + for _, candpkg := range pkgs { + if std, ok := stdlib.PackageSymbols[string(candpkg)]; ok { + for _, sym := range std { + if !usefulCompletion(sym.Name, pattern) { + continue + } + if !versions.AtLeast(c.goversion, sym.Version.String()) { + continue + } + var kind protocol.CompletionItemKind + var detail string + switch sym.Kind { + case stdlib.Func: + kind = protocol.FunctionCompletion + detail = fmt.Sprintf("func (from %q)", candpkg) + case stdlib.Const: + kind = protocol.ConstantCompletion + detail = fmt.Sprintf("const (from %q)", candpkg) + case stdlib.Var: + kind = protocol.VariableCompletion + detail = fmt.Sprintf("var (from %q)", candpkg) + case stdlib.Type: + kind = protocol.VariableCompletion + detail = fmt.Sprintf("type (from %q)", candpkg) + default: + continue + } + got = c.appendNewItem(got, sym.Name, + //fmt.Sprintf("(from %q)", candpkg), candpkg, + detail, + candpkg, + //convKind(sym.Kind), + kind, + pkg, nil) + } + } + } + return got +} + +func (c *completer) modcacheMatches(pkg metadata.PackageName, prefix string) ([]CompletionItem, error) { + ix, err := c.snapshot.View().ModcacheIndex() + if err != nil { + return nil, err + } + if ix == nil || len(ix.Entries) == 0 { // in tests ix might always be nil + return nil, fmt.Errorf("no index %w", err) + } + // retrieve everything and let usefulCompletion() and the matcher sort them out + cands := ix.Lookup(string(pkg), "", true) + lx := len(cands) + got := make([]CompletionItem, 0, lx) + pattern := strings.ToLower(prefix) + for _, cand := range cands { + if !usefulCompletion(cand.Name, pattern) { + continue + } + var params []string + var kind protocol.CompletionItemKind + var detail string + 
switch cand.Type { + case modindex.Func: + for _, f := range cand.Sig { + params = append(params, fmt.Sprintf("%s %s", f.Arg, f.Type)) + } + kind = protocol.FunctionCompletion + detail = fmt.Sprintf("func (from %s)", cand.ImportPath) + case modindex.Var: + kind = protocol.VariableCompletion + detail = fmt.Sprintf("var (from %s)", cand.ImportPath) + case modindex.Const: + kind = protocol.ConstantCompletion + detail = fmt.Sprintf("const (from %s)", cand.ImportPath) + default: + continue + } + got = c.appendNewItem(got, cand.Name, + detail, + metadata.PackagePath(cand.ImportPath), + kind, + pkg, params) + } + return got, nil +} + +func (c *completer) appendNewItem(got []CompletionItem, name, detail string, path metadata.PackagePath, kind protocol.CompletionItemKind, pkg metadata.PackageName, params []string) []CompletionItem { + item := CompletionItem{ + Label: name, + Detail: detail, + InsertText: name, + Kind: kind, + } + imp := importInfo{ + importPath: string(path), + name: string(pkg), + } + if imports.ImportPathToAssumedName(string(path)) == string(pkg) { + imp.name = "" + } + item.AdditionalTextEdits, _ = c.importEdits(&imp) + if params != nil { + var sn snippet.Builder + c.functionCallSnippet(name, nil, params, &sn) + item.snippet = &sn + } + got = append(got, item) + return got +} + +// score the list. Return true if any item is added to c.items +func (c *completer) scoreList(items []CompletionItem) bool { + ret := false + for _, item := range items { + item.Score = float64(c.matcher.Score(item.Label)) + if item.Score > 0 { + c.items = append(c.items, item) + ret = true + } + } + return ret +} + +// pattern is always the result of strings.ToLower +func usefulCompletion(name, pattern string) bool { + // this travesty comes from foo.(type) somehow. 
see issue59096.txt + if pattern == "_" { + return true + } + // convert both to lower case, and then the runes in the pattern have to occur, in order, + // in the name + cand := strings.ToLower(name) + for _, r := range pattern { + ix := strings.IndexRune(cand, r) + if ix < 0 { + return false + } + cand = cand[ix+1:] + } + return true +} + +// return a printed version of the function arguments for snippets +func funcParams(f *ast.File, fname string) []string { + var params []string + setParams := func(list *ast.FieldList) { + if list == nil { + return + } + var cfg printer.Config // slight overkill + param := func(name string, typ ast.Expr) { + var buf strings.Builder + buf.WriteString(name) + buf.WriteByte(' ') + cfg.Fprint(&buf, token.NewFileSet(), typ) + params = append(params, buf.String()) + } + + for _, field := range list.List { + if field.Names != nil { + for _, name := range field.Names { + param(name.Name, field.Type) + } + } else { + param("_", field.Type) + } + } + } + for _, n := range f.Decls { + switch x := n.(type) { + case *ast.FuncDecl: + if x.Recv == nil && x.Name.Name == fname { + setParams(x.Type.Params) + } + } + } + return params +} diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index 8fa03908c01..eb3d0a34161 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -17,6 +17,7 @@ import ( "golang.org/x/telemetry/counter/countertest" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/server" + "golang.org/x/tools/gopls/internal/settings" . 
"golang.org/x/tools/gopls/internal/test/integration" "golang.org/x/tools/gopls/internal/test/integration/fake" "golang.org/x/tools/gopls/internal/util/bug" @@ -305,6 +306,7 @@ func _() { WithOptions( WriteGoSum("."), ProxyFiles(proxy), + Settings{"importsSource": settings.ImportsSourceGopls}, ).Run(t, mod, func(t *testing.T, env *Env) { // Make sure the dependency is in the module cache and accessible for // unimported completions, and then remove it before proceeding. @@ -369,6 +371,7 @@ const Name = "mainmod" ` WithOptions( WriteGoSum("."), + Settings{"importsSource": settings.ImportsSourceGopls}, ProxyFiles(proxy)).Run(t, files, func(t *testing.T, env *Env) { env.CreateBuffer("import.go", "package pkg\nvar _ = mainmod.Name\n") env.SaveBuffer("import.go") @@ -538,6 +541,7 @@ func main() { WithOptions( WindowsLineEndings(), Settings{"ui.completion.usePlaceholders": true}, + Settings{"importsSource": settings.ImportsSourceGopls}, ).Run(t, src, func(t *testing.T, env *Env) { // Trigger unimported completions for the mod.com package. 
env.OpenFile("main.go") @@ -590,7 +594,10 @@ var Lower = "" for _, supportInsertReplace := range []bool{true, false} { t.Run(fmt.Sprintf("insertReplaceSupport=%v", supportInsertReplace), func(t *testing.T) { capabilities := fmt.Sprintf(`{ "textDocument": { "completion": { "completionItem": {"insertReplaceSupport":%t, "snippetSupport": false } } } }`, supportInsertReplace) - runner := WithOptions(CapabilitiesJSON([]byte(capabilities))) + runner := WithOptions( + CapabilitiesJSON([]byte(capabilities)), + Settings{"importsSource": settings.ImportsSourceGopls}, + ) runner.Run(t, src, func(t *testing.T, env *Env) { env.OpenFile("main.go") env.Await(env.DoneWithOpen()) @@ -671,7 +678,8 @@ func F3[K comparable, V any](map[K]V, chan V) {} ` WithOptions( WindowsLineEndings(), - Settings{"ui.completion.usePlaceholders": true}, + Settings{"ui.completion.usePlaceholders": true, + "importsSource": settings.ImportsSourceGopls}, ).Run(t, src, func(t *testing.T, env *Env) { env.OpenFile("a/a.go") env.Await(env.DoneWithOpen()) @@ -681,8 +689,8 @@ func F3[K comparable, V any](map[K]V, chan V) {} for i, want := range []string{ common + "b.F0(${1:a int}, ${2:b int}, ${3:c float64})\r\n", common + "b.F1(${1:_ int}, ${2:_ chan *string})\r\n", - common + "b.F2[${1:K any}, ${2:V any}](${3:_ map[K]V}, ${4:_ chan V})\r\n", - common + "b.F3[${1:K comparable}, ${2:V any}](${3:_ map[K]V}, ${4:_ chan V})\r\n", + common + "b.F2(${1:_ map[K]V}, ${2:_ chan V})\r\n", + common + "b.F3(${1:_ map[K]V}, ${2:_ chan V})\r\n", } { loc := env.RegexpSearch("a/a.go", "b.F()") completions := env.Completion(loc) @@ -1361,6 +1369,7 @@ func Join() {} WithOptions( ProxyFiles(proxy), + Settings{"importsSource": settings.ImportsSourceGopls}, ).Run(t, files, func(t *testing.T, env *Env) { env.RunGoCommand("mod", "download", "golang.org/toolchain@v0.0.1-go1.21.1.linux-amd64") env.OpenFile("foo.go") diff --git a/gopls/internal/test/marker/testdata/completion/issue59096.txt 
b/gopls/internal/test/marker/testdata/completion/issue59096.txt index 23d82c4dc9c..15730043dce 100644 --- a/gopls/internal/test/marker/testdata/completion/issue59096.txt +++ b/gopls/internal/test/marker/testdata/completion/issue59096.txt @@ -2,6 +2,11 @@ This test exercises the panic in golang/go#59096: completing at a syntactic type-assert expression was panicking because gopls was translating it into a (malformed) selector expr. +-- settings.json -- +{ + "importsSource": "gopls" +} + -- go.mod -- module example.com diff --git a/gopls/internal/test/marker/testdata/completion/issue60545.txt b/gopls/internal/test/marker/testdata/completion/issue60545.txt index 4d204979b6a..0f0bb6a6210 100644 --- a/gopls/internal/test/marker/testdata/completion/issue60545.txt +++ b/gopls/internal/test/marker/testdata/completion/issue60545.txt @@ -5,6 +5,11 @@ module mod.test go 1.18 +-- settings.json -- +{ + "importsSource": "gopls" +} + -- main.go -- package main diff --git a/gopls/internal/test/marker/testdata/completion/randv2.txt b/gopls/internal/test/marker/testdata/completion/randv2.txt new file mode 100644 index 00000000000..95c8543bd20 --- /dev/null +++ b/gopls/internal/test/marker/testdata/completion/randv2.txt @@ -0,0 +1,25 @@ +Unimported completions has to find math/rand/v2 +-- flags -- +-min_go=go1.22 +-min_go_command=go1.22 + +-- settings.json -- +{ + "importsSource": "gopls" +} + +-- go.mod -- +module unimported.test + +go 1.22 + +-- main.go -- +package main +var _ = rand.Int64 //@complete(re"Int64", Int64, Int64N, x64, Uint64, Uint64N), diag("rand", re"undefined: rand") +// ordering of these requires completion order be deterministic +// for now, we do not know the types. 
Awaiting CL 665335 +//@item(Int64, "Int64", "func (from \"math/rand/v2\")", "func") +//@item(Int64N, "Int64N", "func (from \"math/rand/v2\")", "func") +//@item(x64, "Uint64", "func (from \"math/rand\")", "func") +//@item(Uint64, "Uint64", "func (from \"math/rand/v2\")", "func") +//@item(Uint64N, "Uint64N", "func (from \"math/rand/v2\")", "func") diff --git a/gopls/internal/test/marker/testdata/completion/unimported.txt b/gopls/internal/test/marker/testdata/completion/unimported.txt index 7d12269c8ba..d5437fb9978 100644 --- a/gopls/internal/test/marker/testdata/completion/unimported.txt +++ b/gopls/internal/test/marker/testdata/completion/unimported.txt @@ -2,6 +2,11 @@ -- flags -- -ignore_extra_diags +-- settings.json -- +{ + "importsSource": "gopls" +} + -- go.mod -- module unimported.test @@ -49,7 +54,7 @@ func _() { /* httptrace */ //@item(httptrace, "httptrace", "\"net/http/httptrace\"", "package") /* httputil */ //@item(httputil, "httputil", "\"net/http/httputil\"", "package") -/* ring.Ring */ //@item(ringring, "Ring", "(from \"container/ring\")", "var") +/* ring.Ring */ //@item(ringring, "Ring", "type (from \"container/ring\")", "var") /* signature.Foo */ //@item(signaturefoo, "Foo", "func (from \"unimported.test/signature\")", "func") diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 89b96381cdc..d2e275934e4 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -27,12 +27,13 @@ import ( "unicode" "unicode/utf8" + "maps" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/gopathwalk" "golang.org/x/tools/internal/stdlib" - "maps" ) // importToGroup is a list of functions which map from an import path to @@ -290,8 +291,8 @@ func (p *pass) loadPackageNames(ctx context.Context, imports []*ImportInfo) erro return nil } -// if there is a trailing major version, remove it -func withoutVersion(nm string) string { +// WithoutVersion removes a
trailing major version, if there is one. +func WithoutVersion(nm string) string { if v := path.Base(nm); len(v) > 0 && v[0] == 'v' { if _, err := strconv.Atoi(v[1:]); err == nil { // this is, for instance, called with rand/v2 and returns rand @@ -313,7 +314,7 @@ func (p *pass) importIdentifier(imp *ImportInfo) string { } known := p.knownPackages[imp.ImportPath] if known != nil && known.Name != "" { - return withoutVersion(known.Name) + return WithoutVersion(known.Name) } return ImportPathToAssumedName(imp.ImportPath) } From 055c1afc85a003b9636a03a065b451a3b694f23e Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 5 May 2025 12:10:52 -0400 Subject: [PATCH 004/196] go/ssa: clear Function.subst after building bodyless function The subst field is unlike the other ones because it is set non-nil at construction, not during building, so it is not sufficient to clear it in finishBody, which is not called for bodyless functions. This change causes it to be cleared in buildParamsOnly too, and documents the coupling between the two operations. 
+ test that the sanity check no longer fails Fixes golang/go#73594 Change-Id: I72b471bf1596493f12f6cd9f3ca59a055834aad5 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669955 Reviewed-by: Robert Findley Reviewed-by: Elias Naur Commit-Queue: Alan Donovan LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- go/ssa/builder.go | 3 +++ go/ssa/builder_test.go | 2 +- go/ssa/func.go | 2 ++ go/ssa/sanity.go | 7 ++++--- go/ssa/testdata/fixedbugs/issue73594.go | 13 +++++++++++++ 5 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 go/ssa/testdata/fixedbugs/issue73594.go diff --git a/go/ssa/builder.go b/go/ssa/builder.go index 84ccbc0927a..b76b75ea025 100644 --- a/go/ssa/builder.go +++ b/go/ssa/builder.go @@ -2920,6 +2920,9 @@ func (b *builder) buildParamsOnly(fn *Function) { for i, n := 0, params.Len(); i < n; i++ { fn.addParamVar(params.At(i)) } + + // clear out other function state (keep consistent with finishBody) + fn.subst = nil } // buildFromSyntax builds fn.Body from fn.syntax, which must be non-nil. 
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index a48723bd271..be710ad66bf 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -1045,8 +1045,8 @@ func TestFixedBugs(t *testing.T) { for _, name := range []string{ "issue66783a", "issue66783b", + "issue73594", } { - t.Run(name, func(t *testing.T) { base := name + ".go" path := filepath.Join(analysistest.TestData(), "fixedbugs", base) diff --git a/go/ssa/func.go b/go/ssa/func.go index 2d52309b623..f48bd7184a4 100644 --- a/go/ssa/func.go +++ b/go/ssa/func.go @@ -386,6 +386,8 @@ func (f *Function) finishBody() { f.results = nil // (used by lifting) f.deferstack = nil // (used by lifting) f.vars = nil // (used by lifting) + + // clear out other function state (keep consistent with buildParamsOnly) f.subst = nil numberRegisters(f) // uses f.namedRegisters diff --git a/go/ssa/sanity.go b/go/ssa/sanity.go index b11680a1e1d..c47a137c884 100644 --- a/go/ssa/sanity.go +++ b/go/ssa/sanity.go @@ -27,9 +27,10 @@ type sanity struct { } // sanityCheck performs integrity checking of the SSA representation -// of the function fn and returns true if it was valid. Diagnostics -// are written to reporter if non-nil, os.Stderr otherwise. Some -// diagnostics are only warnings and do not imply a negative result. +// of the function fn (which must have been "built") and returns true +// if it was valid. Diagnostics are written to reporter if non-nil, +// os.Stderr otherwise. Some diagnostics are only warnings and do not +// imply a negative result. // // Sanity-checking is intended to facilitate the debugging of code // transformation passes. 
diff --git a/go/ssa/testdata/fixedbugs/issue73594.go b/go/ssa/testdata/fixedbugs/issue73594.go new file mode 100644 index 00000000000..a723b8a0da2 --- /dev/null +++ b/go/ssa/testdata/fixedbugs/issue73594.go @@ -0,0 +1,13 @@ +package issue73594 + +// Regression test for sanity-check failure caused by not clearing +// Function.subst after building a body-less instantiated function. + +type genericType[T any] struct{} + +func (genericType[T]) methodWithoutBody() + +func callMethodWithoutBody() { + msg := &genericType[int]{} + msg.methodWithoutBody() +} From 8be05357f1ab7be6ef479fb289ff5a01858ee6a6 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 5 May 2025 13:21:56 -0400 Subject: [PATCH 005/196] gopls/internal/golang: make "Show assembly" work in tests + test Change-Id: Icaaa31067cf586c4d7c97262c4108f92c1168bac Reviewed-on: https://go-review.googlesource.com/c/tools/+/669995 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley Commit-Queue: Alan Donovan --- gopls/internal/golang/assembly.go | 12 +- gopls/internal/golang/compileropt.go | 4 +- .../test/integration/web/assembly_test.go | 112 ++++++++++++------ 3 files changed, 92 insertions(+), 36 deletions(-) diff --git a/gopls/internal/golang/assembly.go b/gopls/internal/golang/assembly.go index 12244a58c59..77a204a5c47 100644 --- a/gopls/internal/golang/assembly.go +++ b/gopls/internal/golang/assembly.go @@ -21,6 +21,7 @@ import ( "html" "io" "net/http" + "os" "regexp" "strconv" "strings" @@ -36,7 +37,16 @@ import ( // See gopls/internal/test/integration/misc/webserver_test.go for tests. func AssemblyHTML(ctx context.Context, snapshot *cache.Snapshot, w http.ResponseWriter, pkg *cache.Package, symbol string, web Web) { // Prepare to compile the package with -S, and capture its stderr stream. 
- inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, pkg.Metadata().CompiledGoFiles[0].DirPath(), "build", []string{"-gcflags=-S", "."}) + // We use "go test -c" not "go build" as it covers all three packages + // (p, "p [p.test]", "p_test [p.test]") in the directory, if they exist. + // (See also compileropt.go.) + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, pkg.Metadata().CompiledGoFiles[0].DirPath(), + "test", []string{ + "-c", + "-o", os.DevNull, + "-gcflags=-S", + ".", + }) if err != nil { // e.g. failed to write overlays (rare) http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/gopls/internal/golang/compileropt.go b/gopls/internal/golang/compileropt.go index bcce82c123f..df6c58145bf 100644 --- a/gopls/internal/golang/compileropt.go +++ b/gopls/internal/golang/compileropt.go @@ -11,7 +11,6 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strings" "golang.org/x/tools/gopls/internal/cache" @@ -44,11 +43,12 @@ func CompilerOptDetails(ctx context.Context, snapshot *cache.Snapshot, pkgDir pr // We use "go test -c" not "go build" as it covers all three packages // (p, "p [p.test]", "p_test [p.test]") in the directory, if they exist. + // (See also assembly.go.) 
inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, pkgDir.Path(), "test", []string{ "-c", "-vet=off", // weirdly -c doesn't disable vet fmt.Sprintf("-gcflags=-json=0,%s", outDirURI), // JSON schema version 0 - fmt.Sprintf("-o=%s", cond(runtime.GOOS == "windows", "NUL", "/dev/null")), + fmt.Sprintf("-o=%s", os.DevNull), ".", }) if err != nil { diff --git a/gopls/internal/test/integration/web/assembly_test.go b/gopls/internal/test/integration/web/assembly_test.go index 6820cbb7864..f8f363f16b3 100644 --- a/gopls/internal/test/integration/web/assembly_test.go +++ b/gopls/internal/test/integration/web/assembly_test.go @@ -5,6 +5,7 @@ package web_test import ( + "regexp" "runtime" "testing" @@ -48,36 +49,6 @@ func init() { Run(t, files, func(t *testing.T, env *Env) { env.OpenFile("a/a.go") - asmFor := func(pattern string) []byte { - // Invoke the "Browse assembly" code action to start the server. - loc := env.RegexpSearch("a/a.go", pattern) - actions, err := env.Editor.CodeAction(env.Ctx, loc, nil, protocol.CodeActionUnknownTrigger) - if err != nil { - t.Fatalf("CodeAction: %v", err) - } - action, err := codeActionByKind(actions, settings.GoAssembly) - if err != nil { - t.Fatal(err) - } - - // Execute the command. - // Its side effect should be a single showDocument request. - params := &protocol.ExecuteCommandParams{ - Command: action.Command.Command, - Arguments: action.Command.Arguments, - } - var result command.DebuggingResult - collectDocs := env.Awaiter.ListenToShownDocuments() - env.ExecuteCommand(params, &result) - doc := shownDocument(t, collectDocs(), "http:") - if doc == nil { - t.Fatalf("no showDocument call had 'file:' prefix") - } - t.Log("showDocument(package doc) URL:", doc.URI) - - return get(t, doc.URI) - } - // Get the report and do some minimal checks for sensible results. // // Use only portable instructions below! 
Remember that @@ -88,7 +59,8 @@ func init() { // We conservatively test only on the two most popular // architectures. { - report := asmFor("println") + loc := env.RegexpSearch("a/a.go", "println") + report := asmFor(t, env, loc) checkMatch(t, true, report, `TEXT.*example.com/a.f`) switch runtime.GOARCH { case "amd64", "arm64": @@ -111,7 +83,8 @@ func init() { // Check that code in a package-level var initializer is found too. { - report := asmFor(`f\(123\)`) + loc := env.RegexpSearch("a/a.go", `f\(123\)`) + report := asmFor(t, env, loc) switch runtime.GOARCH { case "amd64", "arm64": checkMatch(t, true, report, `TEXT.*example.com/a.init`) @@ -123,7 +96,8 @@ func init() { // And code in a source-level init function. { - report := asmFor(`f\(789\)`) + loc := env.RegexpSearch("a/a.go", `f\(789\)`) + report := asmFor(t, env, loc) switch runtime.GOARCH { case "amd64", "arm64": checkMatch(t, true, report, `TEXT.*example.com/a.init`) @@ -133,3 +107,75 @@ func init() { } }) } + +// TestTestAssembly exercises assembly listing of tests. 
+func TestTestAssembly(t *testing.T) { + testenv.NeedsGoCommand1Point(t, 22) // for up-to-date assembly listing + + const files = ` +-- go.mod -- +module example.com + +-- a/a_test.go -- +package a + +import "testing" + +func Test1(*testing.T) { println(0) } + +-- a/a_x_test.go -- +package a_test + +import "testing" + +func Test2(*testing.T) { println(0) } +` + Run(t, files, func(t *testing.T, env *Env) { + for _, test := range []struct { + filename, symbol string + }{ + {"a/a_test.go", "example.com/a.Test1"}, + {"a/a_x_test.go", "example.com/a_test.Test2"}, + } { + env.OpenFile(test.filename) + loc := env.RegexpSearch(test.filename, `println`) + report := asmFor(t, env, loc) + checkMatch(t, true, report, `TEXT.*`+regexp.QuoteMeta(test.symbol)) + switch runtime.GOARCH { + case "amd64", "arm64": + checkMatch(t, true, report, `CALL runtime.printint`) + } + } + }) +} + +// asmFor returns the HTML document served by gopls for a "Show +// assembly" command at the specified location in an open file. +func asmFor(t *testing.T, env *Env, loc protocol.Location) []byte { + // Invoke the "Browse assembly" code action to start the server. + actions, err := env.Editor.CodeAction(env.Ctx, loc, nil, protocol.CodeActionUnknownTrigger) + if err != nil { + t.Fatalf("CodeAction: %v", err) + } + action, err := codeActionByKind(actions, settings.GoAssembly) + if err != nil { + t.Fatal(err) + } + + // Execute the command. + // Its side effect should be a single showDocument request. 
+ params := &protocol.ExecuteCommandParams{ + Command: action.Command.Command, + Arguments: action.Command.Arguments, + } + var result command.DebuggingResult + collectDocs := env.Awaiter.ListenToShownDocuments() + env.ExecuteCommand(params, &result) + doc := shownDocument(t, collectDocs(), "http:") + if doc == nil { + t.Fatalf("no showDocument call had 'file:' prefix") + } + t.Log("showDocument(package doc) URL:", doc.URI) + + return get(t, doc.URI) +} From 0ac692e9300e98da48254da8ab886c12b028d622 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Sun, 27 Apr 2025 11:52:11 -0400 Subject: [PATCH 006/196] gopls/internal/golang: Hover: show allocator size class This CL adds the allocator size class to the size information printed for a struct type, when it is larger than the nominal size. Change-Id: I7a0c6f6fcd5f3a4bc664b88dfa7d0a3d8a8dc358 Reviewed-on: https://go-review.googlesource.com/c/tools/+/668395 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/doc/features/passive.md | 7 ++++-- gopls/internal/golang/hover.go | 19 +++++++++++++++- gopls/internal/golang/hover_test.go | 22 +++++++++++++++++++ .../test/marker/testdata/hover/sizeoffset.txt | 10 +++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 gopls/internal/golang/hover_test.go diff --git a/gopls/doc/features/passive.md b/gopls/doc/features/passive.md index 4557880fdcd..77f7b2f0c06 100644 --- a/gopls/doc/features/passive.md +++ b/gopls/doc/features/passive.md @@ -45,8 +45,11 @@ This information may be useful when optimizing the layout of your data structures, or when reading assembly files or stack traces that refer to each field by its cryptic byte offset. 
-In addition, Hover reports the percentage of wasted space due to -suboptimal ordering of struct fields, if this figure is 20% or higher: +In addition, Hover reports: +- the struct's size class, which is the number of bytes actually + allocated by the Go runtime for a single object of this type; and +- the percentage of wasted space due to suboptimal ordering of struct + fields, if this figure is 20% or higher: diff --git a/gopls/internal/golang/hover.go b/gopls/internal/golang/hover.go index 93c89f3af97..dd04f8908c7 100644 --- a/gopls/internal/golang/hover.go +++ b/gopls/internal/golang/hover.go @@ -396,6 +396,7 @@ func hover(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp pro } // Compute size information for types, + // including allocator size class, // and (size, offset) for struct fields. // // Also, if a struct type's field ordering is significantly @@ -430,7 +431,7 @@ func hover(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp pro path := pathEnclosingObjNode(pgf.File, pos) - // Build string of form "size=... (X% wasted), offset=...". + // Build string of form "size=... (X% wasted), class=..., offset=...". size, wasted, offset := computeSizeOffsetInfo(pkg, path, obj) var buf strings.Builder if size >= 0 { @@ -438,6 +439,11 @@ func hover(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp pro if wasted >= 20 { // >=20% wasted fmt.Fprintf(&buf, " (%d%% wasted)", wasted) } + + // Include allocator size class, if larger. + if class := sizeClass(size); class > size { + fmt.Fprintf(&buf, ", class=%s", format(class)) + } } if offset >= 0 { if buf.Len() > 0 { @@ -1776,3 +1782,14 @@ func computeSizeOffsetInfo(pkg *cache.Package, path []ast.Node, obj types.Object return } + +// sizeClass reports the size class for a struct of the specified size, or -1 if unknown. +// See GOROOT/src/runtime/msize.go for details.
+func sizeClass(size int64) int64 { + if size > 1<<16 { + return -1 // avoid allocation + } + // We assume that bytes.Clone doesn't trim, + // and reports the underlying size class; see TestSizeClass. + return int64(cap(bytes.Clone(make([]byte, size)))) +} diff --git a/gopls/internal/golang/hover_test.go b/gopls/internal/golang/hover_test.go new file mode 100644 index 00000000000..3d55bfe993c --- /dev/null +++ b/gopls/internal/golang/hover_test.go @@ -0,0 +1,22 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package golang + +import "testing" + +func TestSizeClass(t *testing.T) { + // See GOROOT/src/runtime/msize.go for details. + for _, test := range [...]struct{ size, class int64 }{ + {8, 8}, + {9, 16}, + {16, 16}, + {17, 24}, + } { + got := sizeClass(test.size) + if got != test.class { + t.Errorf("sizeClass(%d) = %d, want %d", test.size, got, test.class) + } + } +} diff --git a/gopls/internal/test/marker/testdata/hover/sizeoffset.txt b/gopls/internal/test/marker/testdata/hover/sizeoffset.txt index 54af8cdc6ec..7f475511478 100644 --- a/gopls/internal/test/marker/testdata/hover/sizeoffset.txt +++ b/gopls/internal/test/marker/testdata/hover/sizeoffset.txt @@ -45,6 +45,10 @@ type wasteful struct { //@ hover("wasteful", "wasteful", wasteful) c bool } +type sizeclass struct { //@ hover("sizeclass", "sizeclass", sizeclass) + a [5]*int +} + -- @T -- ```go type T struct { // size=48 (0x30) @@ -65,6 +69,12 @@ type wasteful struct { // size=48 (0x30) (29% wasted) c bool } ``` +-- @sizeclass -- +```go +type sizeclass struct { // size=40 (0x28), class=48 (0x30) + a [5]*int +} +``` -- @a -- ```go field a int // size=8, offset=0 From 6b12a4e06389d434e2dd061d14ad54acc33171b0 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 6 May 2025 15:31:39 +0000 Subject: [PATCH 007/196] internal/mcp/protocol: move out of mcp/internal, as it's used in the API The 
generated protocol package is part of the SDK API, so it shouldn't be internal. Change-Id: I5306af48019b1599707b939b348eb6ae4f766ca9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670356 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 2 +- internal/mcp/cmd_test.go | 2 +- internal/mcp/content.go | 2 +- internal/mcp/content_test.go | 2 +- internal/mcp/examples/hello/main.go | 2 +- internal/mcp/mcp_test.go | 2 +- internal/mcp/prompt.go | 2 +- internal/mcp/prompt_test.go | 2 +- internal/mcp/{internal => }/protocol/content.go | 0 internal/mcp/{internal => }/protocol/doc.go | 0 internal/mcp/{internal => }/protocol/generate.go | 0 internal/mcp/{internal => }/protocol/protocol.go | 0 internal/mcp/server.go | 2 +- internal/mcp/sse_test.go | 2 +- internal/mcp/tool.go | 2 +- internal/mcp/transport.go | 2 +- 16 files changed, 12 insertions(+), 12 deletions(-) rename internal/mcp/{internal => }/protocol/content.go (100%) rename internal/mcp/{internal => }/protocol/doc.go (100%) rename internal/mcp/{internal => }/protocol/generate.go (100%) rename internal/mcp/{internal => }/protocol/protocol.go (100%) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index c9b5b4c40b6..adf8f17f00c 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -12,7 +12,7 @@ import ( "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) // A Client is an MCP client, which may be connected to an MCP server diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index 822d6498883..7cdc0e3237b 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -13,7 +13,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) const runAsServer = "_MCP_RUN_AS_SERVER" diff --git 
a/internal/mcp/content.go b/internal/mcp/content.go index f0e20136fbc..7a3687dd284 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -7,7 +7,7 @@ package mcp import ( "fmt" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) // Content is the union of supported content types: [TextContent], diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index f48f51e7689..950175cb5ac 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) func TestContent(t *testing.T) { diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index 3f80254fd33..b3d8a3ea3e4 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -12,7 +12,7 @@ import ( "os" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) var httpAddr = flag.String("http", "", "if set, use SSE HTTP at this address, instead of stdin/stdout") diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index e6fbcd2949a..9ead44513f4 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -16,8 +16,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "golang.org/x/tools/internal/mcp/internal/protocol" "golang.org/x/tools/internal/mcp/jsonschema" + "golang.org/x/tools/internal/mcp/protocol" ) type hiParams struct { diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 2faa3800ef9..0810cc13254 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -11,9 +11,9 @@ import ( "reflect" "slices" - "golang.org/x/tools/internal/mcp/internal/protocol" "golang.org/x/tools/internal/mcp/internal/util" 
"golang.org/x/tools/internal/mcp/jsonschema" + "golang.org/x/tools/internal/mcp/protocol" ) // A PromptHandler handles a call to prompts/get. diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 6fccc9b936f..912125629b6 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -10,7 +10,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) // testPromptHandler is used for type inference in TestMakePrompt. diff --git a/internal/mcp/internal/protocol/content.go b/internal/mcp/protocol/content.go similarity index 100% rename from internal/mcp/internal/protocol/content.go rename to internal/mcp/protocol/content.go diff --git a/internal/mcp/internal/protocol/doc.go b/internal/mcp/protocol/doc.go similarity index 100% rename from internal/mcp/internal/protocol/doc.go rename to internal/mcp/protocol/doc.go diff --git a/internal/mcp/internal/protocol/generate.go b/internal/mcp/protocol/generate.go similarity index 100% rename from internal/mcp/internal/protocol/generate.go rename to internal/mcp/protocol/generate.go diff --git a/internal/mcp/internal/protocol/protocol.go b/internal/mcp/protocol/protocol.go similarity index 100% rename from internal/mcp/internal/protocol/protocol.go rename to internal/mcp/protocol/protocol.go diff --git a/internal/mcp/server.go b/internal/mcp/server.go index d549db50b71..459fd3ce721 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -13,7 +13,7 @@ import ( "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) // A Server is an instance of an MCP server. 
diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 57a8e746405..e6d355b105f 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -12,7 +12,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" ) func TestSSEServer(t *testing.T) { diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index ca2e6b6b62d..6428a188f46 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -9,9 +9,9 @@ import ( "encoding/json" "slices" - "golang.org/x/tools/internal/mcp/internal/protocol" "golang.org/x/tools/internal/mcp/internal/util" "golang.org/x/tools/internal/mcp/jsonschema" + "golang.org/x/tools/internal/mcp/protocol" ) // A ToolHandler handles a call to tools/call. diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 015b9670e38..1dad3bd1c5a 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -15,7 +15,7 @@ import ( "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/internal/protocol" + "golang.org/x/tools/internal/mcp/protocol" "golang.org/x/tools/internal/xcontext" ) From ff4f533cb8efe0b3690d11a947e1e510cfe1ab2a Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Mon, 5 May 2025 16:33:23 -0400 Subject: [PATCH 008/196] internal/mcp: add README.md and CONTRIBUTING.md Move examples from package doc to README.md with package decl and imports to ensure it's compilable. Move status of the package from package doc to README.md. Add core concepts of MCP with example code snippets and testing recommendations. 
Change-Id: I5bf5000b922fa3942aba922713b12d0347295811 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670316 Reviewed-by: Sam Thanawalla Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/CONTRIBUTING.md | 26 +++++++ internal/mcp/README.md | 132 +++++++++++++++++++++++++++++++++++ internal/mcp/mcp.go | 28 +------- 3 files changed, 159 insertions(+), 27 deletions(-) create mode 100644 internal/mcp/CONTRIBUTING.md create mode 100644 internal/mcp/README.md diff --git a/internal/mcp/CONTRIBUTING.md b/internal/mcp/CONTRIBUTING.md new file mode 100644 index 00000000000..c271074fc01 --- /dev/null +++ b/internal/mcp/CONTRIBUTING.md @@ -0,0 +1,26 @@ +# Contributing to Go MCP package + +Go is an open source project. + +It is the work of hundreds of contributors. We appreciate your help! + +## Filing issues + +When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions: + +1. What version of Go are you using (`go version`)? +2. What operating system and processor architecture are you using? +3. What did you do? +4. What did you expect to see? +5. What did you see instead? + +General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. +The gophers there will answer or ask you to file an issue if you've tripped over a bug. + +## Contributing code + +Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) +before sending patches. + +Unless otherwise noted, the Go source files are distributed under +the BSD-style license found in the LICENSE file. 
diff --git a/internal/mcp/README.md b/internal/mcp/README.md new file mode 100644 index 00000000000..8ce9277bec9 --- /dev/null +++ b/internal/mcp/README.md @@ -0,0 +1,132 @@ +# MCP package + +[![PkgGoDev](https://pkg.go.dev/badge/golang.org/x/tools)](https://pkg.go.dev/golang.org/x/tools/internal/mcp) + +The mcp package provides an SDK for writing [model context protocol](https://modelcontextprotocol.io/introduction) +clients and servers. It is a work-in-progress. As of writing, it is a prototype +to explore the design space of client/server transport and binding. + +## Installation + +The mcp package is currently internal and cannot be imported using `go get`. + +## Quickstart + +Here's an example that creates a client that talks to an MCP server running +as a sidecar process: + +```go +package main + +import ( + "context" + "log" + "os/exec" + + "golang.org/x/tools/internal/mcp" +) + +func main() { + ctx := context.Background() + // Create a new client, with no features. + client := mcp.NewClient("mcp-client", "v1.0.0", nil) + // Connect to a server over stdin/stdout + transport := mcp.NewCommandTransport(exec.Command("myserver")) + if err := client.Connect(ctx, transport, nil); err != nil { + log.Fatal(err) + } + // Call a tool on the server. + if content, err := client.CallTool(ctx, "greet", map[string]any{"name": "you"}); err != nil { + log.Printf("CallTool returns error: %v", err) + } else { + log.Printf("CallTool returns: %v", content) + } +} +``` + +Here is an example of the corresponding server, connected over stdin/stdout: + +```go +package main + +import ( + "context" + + "golang.org/x/tools/internal/mcp" +) + +type HiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) ([]mcp.Content, error) { + return []mcp.Content{ + mcp.TextContent{Text: "Hi " + params.Name}, + }, nil +} + +func main() { + // Create a server with a single tool. 
+ server := mcp.NewServer("greeter", "v1.0.0", nil) + server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) + // Run the server over stdin/stdout, until the client disconnects + _ = server.Run(context.Background(), mcp.NewStdIOTransport(), nil) +} +``` + +## Core Concepts + +The mcp package leverages Go's [reflect](https://pkg.go.dev/reflect) package to +automatically generate the JSON schema for your tools' and prompts' input +parameters. As an MCP server developer, ensure your input parameter structs +include the standard `"json"` tags (as demonstrated in the `HiParams` example). +Refer to the [jsonschema](jsonschema/infer.go) +package for detailed information on schema inference. + +### Tools + +Tools in MCP allow servers to expose executable functions that can be invoked by clients and used by LLMs to perform actions. The server can add tools using: + +```go +... +server := mcp.NewServer("greeter", "v1.0.0", nil) +server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) +... +``` + +### Prompts + +Prompts enable servers to define reusable prompt templates and workflows that clients can easily surface to users and LLMs. The server can add prompts using: + +```go +... +server := mcp.NewServer("greeter", "v0.0.1", nil) +server.AddPrompts(mcp.MakePrompt("greet", "", PromptHi)) +... +``` + +### Resources + +Resources are a core primitive in the Model Context Protocol (MCP) that allow servers to expose data and content that can be read by clients and used as context for LLM interactions. + + +Resources are not supported yet. + +## Testing + +To test your client or server using stdio transport, you can use a local +transport instead of a real stdio transport. See [example](server_example_test.go). + +To test your client or server using SSE transport, you can use the [httptest](https://pkg.go.dev/net/http/httptest) +package. See [example](sse_example_test.go).
+ +## Code of Conduct + +This project follows the [Go Community Code of Conduct](https://go.dev/conduct). +If you encounter a conduct-related issue, please mail conduct@golang.org. + +## License + +Unless otherwise noted, the Go source files are distributed under the BSD-style license found in the [LICENSE](../../LICENSE) file. + +Upon a potential move to [modelcontextprotocol](https://github.com/modelcontextprotocol), the license will be updated to the MIT License, and the license header will reflect the Go MCP SDK Authors. diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 40710b68edc..5cc0549bcca 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -3,8 +3,7 @@ // license that can be found in the LICENSE file. // The mcp package provides an SDK for writing model context protocol clients -// and servers. It is a work-in-progress. As of writing, it is a prototype to -// explore the design space of client/server transport and binding. +// and servers. // // To get started, create either a [Client] or [Server], and connect it to a // peer using a [Transport]. The diagram below illustrates how this works: @@ -35,31 +34,6 @@ // may define their own custom Transports by implementing the [Transport] // interface. // -// Here's an example that creates a client that talks to an MCP server running -// as a sidecar process: -// -// import "golang.org/x/tools/internal/mcp" -// ... -// // Create a new client, with no features. -// client := mcp.NewClient("mcp-client", "v1.0.0", nil) -// // Connect to a server over stdin/stdout -// transport := mcp.NewCommandTransport(exec.Command("myserver")) -// if err := client.Connect(ctx, transport, nil); err != nil { -// log.Fatal(err) -// } -// // Call a tool on the server. -// content, err := client.CallTool(ctx, "greet", map[string]any{"name": "you"}) -// -// Here is an example of the corresponding server, connected over stdin/stdout: -// -// import "golang.org/x/tools/internal/mcp" -// ... 
-// // Create a server with a single tool. -// server := mcp.NewServer("greeter", "v1.0.0", nil) -// server.AddTool(mcp.MakeTool("greet", "say hi", SayHi)) -// // Run the server over stdin/stdout, until the client diconnects -// _ = server.Run(ctx, mcp.NewStdIOTransport(), nil) -// // # TODO // // - Support all content types. From 887e16cb1da986f0c698e279255bed9b66f236c4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 3 May 2025 06:54:00 -0400 Subject: [PATCH 009/196] internal/mcp: NewClient takes Transport Pass the client Transport to NewClient instead of Connect. Combine the ConnectionOptions with the ClientOptions. Change-Id: I3acd6cf6d5bf46666742f2a73e2acaef2057ea77 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670215 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 30 ++++++++++++++++++----------- internal/mcp/cmd_test.go | 4 ++-- internal/mcp/mcp.go | 2 +- internal/mcp/mcp_test.go | 13 ++++++------- internal/mcp/server_example_test.go | 4 ++-- internal/mcp/sse_example_test.go | 4 ++-- internal/mcp/sse_test.go | 4 ++-- internal/mcp/transport.go | 2 +- 8 files changed, 35 insertions(+), 28 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index adf8f17f00c..214b516d00b 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -16,10 +16,12 @@ import ( ) // A Client is an MCP client, which may be connected to an MCP server -// using the [Client.Connect] method. +// using the [Client.Start] method. type Client struct { name string version string + transport Transport + opts ClientOptions mu sync.Mutex conn *jsonrpc2.Connection initializeResult *protocol.InitializeResult @@ -27,18 +29,25 @@ type Client struct { // NewClient creates a new Client. // -// Use [Client.Connect] to connect it to an MCP server. +// Use [Client.Start] to connect it to an MCP server. // // If non-nil, the provided options configure the Client. 
-func NewClient(name, version string, opts *ClientOptions) *Client { - return &Client{ - name: name, - version: version, +func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { + c := &Client{ + name: name, + version: version, + transport: t, } + if opts != nil { + c.opts = *opts + } + return c } // ClientOptions configures the behavior of the client. -type ClientOptions struct{} +type ClientOptions struct { + ConnectionOptions +} // bind implements the binder[*ServerConnection] interface, so that Clients can // be connected using [connect]. @@ -56,20 +65,19 @@ func (c *Client) disconnect(*Client) { // return an error. } -// Connect connects the MCP client over the given transport and initializes an -// MCP session. +// Start begins an MCP session by connecting the MCP client over its transport. // // Typically, it is the responsibility of the client to close the connection // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. 
-func (c *Client) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (err error) { +func (c *Client) Start(ctx context.Context) (err error) { defer func() { if err != nil { _ = c.Close() } }() - _, err = connect(ctx, t, opts, c) + _, err = connect(ctx, c.transport, &c.opts.ConnectionOptions, c) if err != nil { return err } diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index 7cdc0e3237b..c730c4b58ee 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -49,8 +49,8 @@ func TestCmdTransport(t *testing.T) { cmd := exec.Command(exe) cmd.Env = append(os.Environ(), runAsServer+"=true") - client := mcp.NewClient("client", "v0.0.1", nil) - if err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil); err != nil { + client := mcp.NewClient("client", "v0.0.1", mcp.NewCommandTransport(cmd), nil) + if err := client.Start(ctx); err != nil { log.Fatal(err) } got, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}) diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 5cc0549bcca..fa35b92e0dc 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -14,7 +14,7 @@ // // A [Client] is an MCP client, which can be configured with various client // capabilities. Clients may be connected to a [Server] instance -// using the [Client.Connect] method. +// using the [Client.Start] method. // // Similarly, a [Server] is an MCP server, which can be configured with various // server capabilities. Servers may be connected to one or more [Client] diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 9ead44513f4..a3893c3f113 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -81,10 +81,10 @@ func TestEndToEnd(t *testing.T) { clientWG.Done() }() - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient("testClient", "v1.0.0", ct, nil) // Connect the client. 
- if err := c.Connect(ctx, ct, nil); err != nil { + if err := c.Start(ctx); err != nil { t.Fatal(err) } @@ -210,8 +210,8 @@ func basicConnection(t *testing.T, tools ...*Tool) (*ClientConnection, *Client) t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) - if err := c.Connect(ctx, ct, nil); err != nil { + c := NewClient("testClient", "v1.0.0", ct, nil) + if err := c.Start(ctx); err != nil { t.Fatal(err) } return cc, c @@ -250,13 +250,12 @@ func TestBatching(t *testing.T) { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) - opts := new(ConnectionOptions) + c := NewClient("testClient", "v1.0.0", ct, nil) // TODO: this test is broken, because increasing the batch size here causes // 'initialize' to block. Therefore, we can only test with a size of 1. const batchSize = 1 BatchSize(ct, batchSize) - if err := c.Connect(ctx, ct, opts); err != nil { + if err := c.Start(ctx); err != nil { t.Fatal(err) } defer c.Close() diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index e532416cb6e..840c09eccdf 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -34,8 +34,8 @@ func ExampleServer() { log.Fatal(err) } - client := mcp.NewClient("client", "v0.0.1", nil) - if err := client.Connect(ctx, clientTransport, nil); err != nil { + client := mcp.NewClient("client", "v0.0.1", clientTransport, nil) + if err := client.Start(ctx); err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 7aeb24d1154..9bf3be50fa5 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -34,8 +34,8 @@ func ExampleSSEHandler() { ctx := context.Background() transport := mcp.NewSSEClientTransport(httpServer.URL) - client := mcp.NewClient("test", "v1.0.0", nil) - if err := client.Connect(ctx, transport, nil); err != nil { + client := mcp.NewClient("test", "v1.0.0", transport, nil) + if err := client.Start(ctx); err != nil 
{ log.Fatal(err) } defer client.Close() diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index e6d355b105f..99f9a294b02 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -36,8 +36,8 @@ func TestSSEServer(t *testing.T) { clientTransport := NewSSEClientTransport(httpServer.URL) - c := NewClient("testClient", "v1.0.0", nil) - if err := c.Connect(ctx, clientTransport, nil); err != nil { + c := NewClient("testClient", "v1.0.0", clientTransport, nil) + if err := c.Start(ctx); err != nil { t.Fatal(err) } if err := c.Ping(ctx); err != nil { diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 1dad3bd1c5a..1f19aa27b9d 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -27,7 +27,7 @@ var ErrConnectionClosed = errors.New("connection closed") // and server. // // Transports should be used for at most one call to [Server.Connect] or -// [Client.Connect]. +// [Client.Start]. type Transport interface { // Connect returns the logical stream. 
// From deec52fd04ab27db6bc028303a68d35de822bb9e Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 6 May 2025 13:40:06 -0400 Subject: [PATCH 010/196] internal/typesinternal: use TypeAndValue.IsBuiltin in ClassifyCall Change-Id: I4bfeee9bcf452783d567e58be754b75e59626dd1 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670396 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/typesinternal/classify_call.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/internal/typesinternal/classify_call.go b/internal/typesinternal/classify_call.go index 649c82b6bea..3db2a135b97 100644 --- a/internal/typesinternal/classify_call.go +++ b/internal/typesinternal/classify_call.go @@ -65,14 +65,16 @@ func ClassifyCall(info *types.Info, call *ast.CallExpr) CallKind { if info.Types == nil { panic("ClassifyCall: info.Types is nil") } - if info.Types[call.Fun].IsType() { + tv := info.Types[call.Fun] + if tv.IsType() { return CallConversion } + if tv.IsBuiltin() { + return CallBuiltin + } obj := info.Uses[UsedIdent(info, call.Fun)] // Classify the call by the type of the object, if any. 
switch obj := obj.(type) { - case *types.Builtin: - return CallBuiltin case *types.Func: if interfaceMethod(obj) { return CallInterface From 7231669975d2fa642ea8f8e1f7a700c9aa650c89 Mon Sep 17 00:00:00 2001 From: cuishuang Date: Wed, 30 Apr 2025 17:42:46 +0800 Subject: [PATCH 011/196] gopls/internal/analysis/modernize: don't offer a fix when initialization statement is not empty Fixes golang/go#73547 Change-Id: I878f7ab71c1dce896f5eef7fb319cf99b2394f88 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669055 Reviewed-by: Alan Donovan Auto-Submit: Alan Donovan Reviewed-by: Cherry Mui LUCI-TryBot-Result: Go LUCI --- gopls/internal/analysis/modernize/stringscutprefix.go | 3 ++- .../testdata/src/stringscutprefix/stringscutprefix.go | 4 ++++ .../testdata/src/stringscutprefix/stringscutprefix.go.golden | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/gopls/internal/analysis/modernize/stringscutprefix.go b/gopls/internal/analysis/modernize/stringscutprefix.go index f8e9be63e3c..f04c0b2ebe8 100644 --- a/gopls/internal/analysis/modernize/stringscutprefix.go +++ b/gopls/internal/analysis/modernize/stringscutprefix.go @@ -59,7 +59,8 @@ func stringscutprefix(pass *analysis.Pass) { ifStmt := curIfStmt.Node().(*ast.IfStmt) // pattern1 - if call, ok := ifStmt.Cond.(*ast.CallExpr); ok && len(ifStmt.Body.List) > 0 { + if call, ok := ifStmt.Cond.(*ast.CallExpr); ok && ifStmt.Init == nil && len(ifStmt.Body.List) > 0 { + obj := typeutil.Callee(info, call) if !analysisinternal.IsFunctionNamed(obj, "strings", "HasPrefix") && !analysisinternal.IsFunctionNamed(obj, "bytes", "HasPrefix") { diff --git a/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go b/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go index 7679bdb6e67..c108df3fd29 100644 --- a/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go +++ 
b/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go @@ -59,6 +59,10 @@ func _() { a := strings.TrimPrefix(s, pre) // noop, as the argument isn't the same _ = a } + if s1 := s; strings.HasPrefix(s1, pre) { + a := strings.TrimPrefix(s1, pre) // noop, as IfStmt.Init is present + _ = a + } } var value0 string diff --git a/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go.golden b/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go.golden index a6c52b08802..caf52c42606 100644 --- a/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/stringscutprefix/stringscutprefix.go.golden @@ -59,6 +59,10 @@ func _() { a := strings.TrimPrefix(s, pre) // noop, as the argument isn't the same _ = a } + if s1 := s; strings.HasPrefix(s1, pre) { + a := strings.TrimPrefix(s1, pre) // noop, as IfStmt.Init is present + _ = a + } } var value0 string From f8980b642a255c5dca4810257aa0eac888dad5a7 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 6 May 2025 07:25:27 -0400 Subject: [PATCH 012/196] internal/mcp/jsonschema: support validating structs The instance passed to Resolved.Validate can be a struct, or can contain structs. A struct is treated like the JSON object that it marshals to. Also support pointers to any depth. That is, if an instance is a **int, then Validate will treat it like an int (or nil). That won't happen if the instance is unmarshaled into an `any`, but may if it's unmarshaled into a struct. 
Change-Id: I39f9af58028bd779887495754519615cd5dfb6c8 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670395 Reviewed-by: Alan Donovan Auto-Submit: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/infer.go | 2 +- internal/mcp/jsonschema/schema.go | 7 +- internal/mcp/jsonschema/validate.go | 131 ++++++++++++++++++++--- internal/mcp/jsonschema/validate_test.go | 32 ++++++ 4 files changed, 153 insertions(+), 19 deletions(-) diff --git a/internal/mcp/jsonschema/infer.go b/internal/mcp/jsonschema/infer.go index 7fbf3fe630f..b5605fd56a1 100644 --- a/internal/mcp/jsonschema/infer.go +++ b/internal/mcp/jsonschema/infer.go @@ -102,7 +102,7 @@ func typeSchema(t reflect.Type, seen map[reflect.Type]*Schema) (*Schema, error) case reflect.Struct: s.Type = "object" // no additional properties are allowed - s.AdditionalProperties = &Schema{Not: &Schema{}} + s.AdditionalProperties = falseSchema() for i := range t.NumField() { field := t.Field(i) diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 1b658045cc9..5bb6de6eb56 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -121,6 +121,11 @@ type Schema struct { patternProperties map[*regexp.Regexp]*Schema } +// falseSchema returns a new Schema tree that fails to validate any value. +func falseSchema() *Schema { + return &Schema{Not: &Schema{}} +} + // String returns a short description of the schema. func (s *Schema) String() string { if s.ID != "" { @@ -182,7 +187,7 @@ func (s *Schema) UnmarshalJSON(data []byte) error { *s = Schema{} } else { // false is the schema that validates nothing. 
- *s = Schema{Not: &Schema{}} + *s = *falseSchema() } return nil } diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 60f03c55412..28ab0fb1e4b 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -7,11 +7,13 @@ package jsonschema import ( "fmt" "hash/maphash" + "iter" "math" "math/big" "reflect" "slices" "strings" + "sync" "unicode/utf8" ) @@ -55,8 +57,8 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // We checked for nil schemas in [Schema.Resolve]. assert(schema != nil, "nil schema") - // Step through interfaces. - if instance.IsValid() && instance.Kind() == reflect.Interface { + // Step through interfaces and pointers. + for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { instance = instance.Elem() } @@ -324,16 +326,18 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // objects // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 - if instance.Kind() == reflect.Map { - if kt := instance.Type().Key(); kt.Kind() != reflect.String { - return fmt.Errorf("map key type %s is not a string", kt) + if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } } // Track the evaluated properties for just this schema, to support additionalProperties. // If we used anns here, then we'd be including properties evaluated in subschemas // from allOf, etc., which additionalProperties shouldn't observe. evalProps := map[string]bool{} for prop, schema := range schema.Properties { - val := instance.MapIndex(reflect.ValueOf(prop)) + val := property(instance, prop) if !val.IsValid() { // It's OK if the instance doesn't have the property. 
continue @@ -344,8 +348,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an evalProps[prop] = true } if len(schema.PatternProperties) > 0 { - for vprop, val := range instance.Seq2() { - prop := vprop.String() + for prop, val := range properties(instance) { // Check every matching pattern. for re, schema := range schema.patternProperties { if re.MatchString(prop) { @@ -359,8 +362,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } if schema.AdditionalProperties != nil { // Apply to all properties not handled above. - for vprop, val := range instance.Seq2() { - prop := vprop.String() + for prop, val := range properties(instance) { if !evalProps[prop] { if err := st.validate(val, schema.AdditionalProperties, nil, append(path, prop)); err != nil { return err @@ -371,8 +373,10 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } anns.noteProperties(evalProps) if schema.PropertyNames != nil { - for prop := range instance.Seq() { - if err := st.validate(prop, schema.PropertyNames, nil, append(path, prop.String())); err != nil { + // Note: properties unnecessarily fetches each value. We could define a propertyNames function + // if performance ever matters. 
+ for prop := range properties(instance) { + if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil, append(path, prop)); err != nil { return err } } @@ -380,18 +384,18 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 if schema.MinProperties != nil { - if n, m := instance.Len(), *schema.MinProperties; n < m { + if n, m := numProperties(instance), *schema.MinProperties; n < m { return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) } } if schema.MaxProperties != nil { - if n, m := instance.Len(), *schema.MaxProperties; n > m { + if n, m := numProperties(instance), *schema.MaxProperties; n > m { return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) } } hasProperty := func(prop string) bool { - return instance.MapIndex(reflect.ValueOf(prop)).IsValid() + return property(instance, prop).IsValid() } missingProperties := func(props []string) []string { @@ -438,8 +442,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if schema.UnevaluatedProperties != nil && !anns.allProperties { // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf // in addition to sibling keywords. - for vprop, val := range instance.Seq2() { - prop := vprop.String() + for prop, val := range properties(instance) { if !anns.evaluatedProperties[prop] { if err := st.validate(val, schema.UnevaluatedProperties, nil, append(path, prop)); err != nil { return err @@ -460,6 +463,100 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an return nil } +// property returns the value of the property of v with the given name, or the invalid +// reflect.Value if there is none. +// If v is a map, the property is the value of the map whose key is name. 
+// If v is a struct, the property is the value of the field with the given name according +// to the encoding/json package (see [jsonName]). +// If v is anything else, property panics. +func property(v reflect.Value, name string) reflect.Value { + switch v.Kind() { + case reflect.Map: + return v.MapIndex(reflect.ValueOf(name)) + case reflect.Struct: + props := structPropertiesOf(v.Type()) + if index, ok := props[name]; ok { + return v.FieldByIndex(index) + } + return reflect.Value{} + default: + panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) + } +} + +// properties returns an iterator over the names and values of all properties +// in v, which must be a map or a struct. +func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { + return func(yield func(string, reflect.Value) bool) { + switch v.Kind() { + case reflect.Map: + for k, e := range v.Seq2() { + if !yield(k.String(), e) { + return + } + } + case reflect.Struct: + for name, index := range structPropertiesOf(v.Type()) { + if !yield(name, v.FieldByIndex(index)) { + return + } + } + default: + panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) + } + } +} + +// numProperties returns the number of v's properties. +// v must be a map or a struct. +func numProperties(v reflect.Value) int { + switch v.Kind() { + case reflect.Map: + return v.Len() + case reflect.Struct: + return len(structPropertiesOf(v.Type())) + default: + panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) + } +} + +// A propertyMap is a map from property name to struct field index. +type propertyMap = map[string][]int + +var structProperties sync.Map // from reflect.Type to propertyMap + +// structPropertiesOf returns the JSON Schema properties for the struct type t. +// The caller must not mutate the result. +func structPropertiesOf(t reflect.Type) propertyMap { + // Mutex not necessary: at worst we'll recompute the same value. 
+ if props, ok := structProperties.Load(t); ok { + return props.(propertyMap) + } + props := map[string][]int{} + for _, sf := range reflect.VisibleFields(t) { + if name, ok := jsonName(sf); ok { + props[name] = sf.Index + } + } + structProperties.Store(t, props) + return props +} + +// jsonName returns the name for f as would be used by [json.Marshal]. +// That is the name in the json struct tag, or the field name if there is no tag. +// If f is not exported or the tag name is "-", jsonName returns "", false. +func jsonName(f reflect.StructField) (string, bool) { + if !f.IsExported() { + return "", false + } + if tag, ok := f.Tag.Lookup("json"); ok { + if name, _, _ := strings.Cut(tag, ","); name != "" { + return name, name != "-" + } + } + return f.Name, true +} + func formatPath(path []any) string { var b strings.Builder for i, p := range path { diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 7abf5f0fd27..79cd19e51e3 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -80,3 +80,35 @@ func TestValidate(t *testing.T) { }) } } + +func TestStructInstance(t *testing.T) { + instance := struct { + I int + B bool `json:"b"` + u int + }{1, true, 0} + + // The instance fails for all of these schemas, demonstrating that it + // was processed correctly. 
+ for _, schema := range []*Schema{ + {MinProperties: Ptr(3)}, + {MaxProperties: Ptr(1)}, + {Required: []string{"i"}}, // the name is "I" + {Required: []string{"B"}}, // the name is "b" + {PropertyNames: &Schema{MinLength: Ptr(2)}}, + {Properties: map[string]*Schema{"b": {Type: "number"}}}, + {Required: []string{"I"}, AdditionalProperties: falseSchema()}, + {DependentRequired: map[string][]string{"b": {"u"}}}, + {DependentSchemas: map[string]*Schema{"b": falseSchema()}}, + {UnevaluatedProperties: falseSchema()}, + } { + res, err := schema.Resolve("") + if err != nil { + t.Fatal(err) + } + err = res.Validate(instance) + if err == nil { + t.Errorf("succeeded but wanted failure; schema = %s", schema.json()) + } + } +} From f71ad04fb8c814bbb06c54b6f41641fbb6493d56 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 15:34:28 -0400 Subject: [PATCH 013/196] jsonschema: implement JSON Pointers for schemas Implement enough of the JSON Pointer spec to suffice for JSON Schema. For details, see the comment at the top of json_pointer.go. Change-Id: Ic297f1a1a08903b4b3c6666fe57ac1f45c96c0f7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670535 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/json_pointer.go | 150 +++++++++++++++++++ internal/mcp/jsonschema/json_pointer_test.go | 78 ++++++++++ 2 files changed, 228 insertions(+) create mode 100644 internal/mcp/jsonschema/json_pointer.go create mode 100644 internal/mcp/jsonschema/json_pointer_test.go diff --git a/internal/mcp/jsonschema/json_pointer.go b/internal/mcp/jsonschema/json_pointer.go new file mode 100644 index 00000000000..687743ffbae --- /dev/null +++ b/internal/mcp/jsonschema/json_pointer.go @@ -0,0 +1,150 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements JSON Pointers. 
+// A JSON Pointer is a path that refers to one JSON value within another. +// If the path is empty, it refers to the root value. +// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", +// selecting successive properties (for JSON objects) or items (for JSON arrays). +// For example, when applied to this JSON value: +// { +// "points": [ +// {"x": 1, "y": 2}, +// {"x": 3, "y": 4} +// ] +// } +// +// the JSON Pointer "/points/1/x" refers to the number 3. +// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. + +package jsonschema + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +var jsonPointerReplacer = strings.NewReplacer("~0", "~", "~1", "/") + +// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't +// convert strings to numbers, because that depends on the traversal: a segment +// is treated as a number when applied to an array, but a string when applied to +// an object. See section 4 of the spec. +func parseJSONPointer(ptr string) (segments []string, err error) { + if ptr == "" { + return nil, nil + } + if ptr[0] != '/' { + return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) + } + // Unlike file paths, consecutive slashes are not coalesced. + // Split is nicer than Cut here, because it gets a final "/" right. + segments = strings.Split(ptr[1:], "/") + if strings.Contains(ptr, "~") { + // Undo the simple escaping rules that allow one to include a slash in a segment. + for i := range segments { + segments[i] = jsonPointerReplacer.Replace(segments[i]) + } + } + return segments, nil +} + +// dereferenceJSONPointer returns the Schema that sptr points to within s, +// or an error if none. +// This implementation suffices for JSON Schema: pointers are applied only to Schemas, +// and refer only to Schemas. 
+func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("JSON Pointer %q: %w", sptr, err) + } + }() + + segments, err := parseJSONPointer(sptr) + if err != nil { + return nil, err + } + v := reflect.ValueOf(s) + for _, seg := range segments { + switch v.Kind() { + case reflect.Pointer: + v = v.Elem() + if !v.IsValid() { + return nil, errors.New("navigated to nil reference") + } + fallthrough // if valid, can only be a pointer to a Schema + + case reflect.Struct: + // The segment must refer to a field in a Schema. + if v.Type() != reflect.TypeFor[Schema]() { + return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) + } + v = lookupSchemaField(v, seg) + if !v.IsValid() { + return nil, fmt.Errorf("no schema field %q", seg) + } + case reflect.Slice, reflect.Array: + // The segment must be an integer without leading zeroes that refers to an item in the + // slice or array. + if seg == "-" { + return nil, errors.New("the JSON Pointer array segment '-' is not supported") + } + if len(seg) > 1 && seg[0] == '0' { + return nil, fmt.Errorf("segment %q has leading zeroes", seg) + } + n, err := strconv.Atoi(seg) + if err != nil { + return nil, fmt.Errorf("invalid int: %q", seg) + } + if n < 0 || n >= v.Len() { + return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) + } + v = v.Index(n) + // Cannot be invalid. + case reflect.Map: + // The segment must be a key in the map. + v = v.MapIndex(reflect.ValueOf(seg)) + if !v.IsValid() { + return nil, fmt.Errorf("no key %q in map", seg) + } + default: + return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) + } + } + if s, ok := v.Interface().(*Schema); ok { + return s, nil + } + return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) +} + +// map from JSON names for fields in a Schema to their indexes in the struct. 
+var schemaFields = map[string][]int{} + +func init() { + for _, f := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { + if name, ok := jsonName(f); ok { + schemaFields[name] = f.Index + } + } +} + +// lookupSchemaField returns the value of the field with the given name in v, +// or the zero value if there is no such field or it is not of type Schema or *Schema. +func lookupSchemaField(v reflect.Value, name string) reflect.Value { + if name == "type" { + // The "type" keyword may refer to Type or Types. + // At most one will be non-zero. + if t := v.FieldByName("Type"); !t.IsZero() { + return t + } + return v.FieldByName("Types") + } + if index := schemaFields[name]; index != nil { + return v.FieldByIndex(index) + } + return reflect.Value{} +} diff --git a/internal/mcp/jsonschema/json_pointer_test.go b/internal/mcp/jsonschema/json_pointer_test.go new file mode 100644 index 00000000000..d31e19cdf9f --- /dev/null +++ b/internal/mcp/jsonschema/json_pointer_test.go @@ -0,0 +1,78 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package jsonschema + +import ( + "strings" + "testing" +) + +func TestDereferenceJSONPointer(t *testing.T) { + s := &Schema{ + AllOf: []*Schema{{}, {}}, + Defs: map[string]*Schema{ + "": {Properties: map[string]*Schema{"": {}}}, + "A": {}, + "B": { + Defs: map[string]*Schema{ + "X": {}, + "Y": {}, + }, + }, + "/~": {}, + "~1": {}, + }, + } + + for _, tt := range []struct { + ptr string + want any + }{ + {"", s}, + {"/$defs/A", s.Defs["A"]}, + {"/$defs/B", s.Defs["B"]}, + {"/$defs/B/$defs/X", s.Defs["B"].Defs["X"]}, + {"/$defs//properties/", s.Defs[""].Properties[""]}, + {"/allOf/1", s.AllOf[1]}, + {"/$defs/~1~0", s.Defs["/~"]}, + {"/$defs/~01", s.Defs["~1"]}, + } { + got, err := dereferenceJSONPointer(s, tt.ptr) + if err != nil { + t.Fatal(err) + } + if got != tt.want { + t.Errorf("%s:\ngot %+v\nwant %+v", tt.ptr, got, tt.want) + } + } +} + +func TestDerefernceJSONPointerErrors(t *testing.T) { + s := &Schema{ + Type: "t", + Items: &Schema{}, + Required: []string{"a"}, + } + for _, tt := range []struct { + ptr string + want string // error must contain this string + }{ + {"x", "does not begin"}, // parse error: no initial '/' + {"/minItems", "does not refer to a schema"}, + {"/minItems/x", "navigated to nil"}, + {"/required/-", "not supported"}, + {"/required/01", "leading zeroes"}, + {"/required/x", "invalid int"}, + {"/required/1", "out of bounds"}, + {"/properties/x", "no key"}, + } { + _, err := dereferenceJSONPointer(s, tt.ptr) + if err == nil { + t.Errorf("%q: succeeded, want failure", tt.ptr) + } else if !strings.Contains(err.Error(), tt.want) { + t.Errorf("%q: error is %q, which does not contain %q", tt.ptr, err, tt.want) + } + } +} From c83623246594e4a2da815d6e8ee2355578bcb6d5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 7 May 2025 11:50:44 -0400 Subject: [PATCH 014/196] internal/mcp: rename Make* to New* Change-Id: I8d3d2ca125b7c24dd265937c4fe910ea823ab811 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670537 Reviewed-by: 
Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/cmd_test.go | 2 +- internal/mcp/examples/hello/main.go | 4 ++-- internal/mcp/examples/sse/main.go | 4 ++-- internal/mcp/mcp_test.go | 12 ++++++------ internal/mcp/prompt.go | 4 ++-- internal/mcp/prompt_test.go | 12 ++++++------ internal/mcp/server_example_test.go | 2 +- internal/mcp/sse_example_test.go | 2 +- internal/mcp/sse_test.go | 2 +- internal/mcp/tool.go | 6 +++--- internal/mcp/tool_test.go | 14 +++++++------- 11 files changed, 32 insertions(+), 32 deletions(-) diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index c730c4b58ee..e9f2c8fbb64 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -31,7 +31,7 @@ func runServer() { ctx := context.Background() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) + server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) if err := server.Run(ctx, mcp.NewStdIOTransport(), nil); err != nil { log.Fatal(err) diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index b3d8a3ea3e4..d85317705d3 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -40,10 +40,10 @@ func main() { flag.Parse() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.MakeTool("greet", "say hi", SayHi, mcp.Input( + server.AddTools(mcp.NewTool("greet", "say hi", SayHi, mcp.Input( mcp.Property("name", mcp.Description("the name to say hi to")), ))) - server.AddPrompts(mcp.MakePrompt("greet", "", PromptHi)) + server.AddPrompts(mcp.NewPrompt("greet", "", PromptHi)) if *httpAddr != "" { handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index c95fd34c746..14c51b910c1 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -33,10 +33,10 @@ func main() { } server1 := mcp.NewServer("greeter1", 
"v0.0.1", nil) - server1.AddTools(mcp.MakeTool("greet1", "say hi", SayHi)) + server1.AddTools(mcp.NewTool("greet1", "say hi", SayHi)) server2 := mcp.NewServer("greeter2", "v0.0.1", nil) - server2.AddTools(mcp.MakeTool("greet2", "say hello", SayHi)) + server2.AddTools(mcp.NewTool("greet2", "say hello", SayHi)) log.Printf("MCP servers serving at %s\n", *httpAddr) handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index a3893c3f113..a91d50a3705 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -38,18 +38,18 @@ func TestEndToEnd(t *testing.T) { s := NewServer("testServer", "v1.0.0", nil) // The 'greet' tool says hi. - s.AddTools(MakeTool("greet", "say hi", sayHi)) + s.AddTools(NewTool("greet", "say hi", sayHi)) // The 'fail' tool returns this error. failure := errors.New("mcp failure") s.AddTools( - MakeTool("fail", "just fail", func(context.Context, *ClientConnection, struct{}) ([]Content, error) { + NewTool("fail", "just fail", func(context.Context, *ClientConnection, struct{}) ([]Content, error) { return nil, failure }), ) s.AddPrompts( - MakePrompt("code_review", "do a code review", func(_ context.Context, _ *ClientConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) { + NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ClientConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) { return &protocol.GetPromptResult{ Description: "Code review prompt", Messages: []protocol.PromptMessage{ @@ -57,7 +57,7 @@ func TestEndToEnd(t *testing.T) { }, }, nil }), - MakePrompt("fail", "", func(_ context.Context, _ *ClientConnection, params struct{}) (*protocol.GetPromptResult, error) { + NewPrompt("fail", "", func(_ context.Context, _ *ClientConnection, params struct{}) (*protocol.GetPromptResult, error) { return nil, failure }), ) @@ -218,7 +218,7 @@ func basicConnection(t *testing.T, tools ...*Tool) 
(*ClientConnection, *Client) } func TestServerClosing(t *testing.T) { - cc, c := basicConnection(t, MakeTool("greet", "say hi", sayHi)) + cc, c := basicConnection(t, NewTool("greet", "say hi", sayHi)) defer c.Close() ctx := context.Background() @@ -293,7 +293,7 @@ func TestCancellation(t *testing.T) { } return nil, nil } - _, sc := basicConnection(t, MakeTool("slow", "a slow request", slowRequest)) + _, sc := basicConnection(t, NewTool("slow", "a slow request", slowRequest)) defer sc.Close() ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 0810cc13254..3b816048d13 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -25,7 +25,7 @@ type Prompt struct { Handler PromptHandler } -// MakePrompt is a helper to use reflection to create a prompt for the given +// NewPrompt is a helper to use reflection to create a prompt for the given // handler. // // The arguments for the prompt are extracted from the request type for the @@ -33,7 +33,7 @@ type Prompt struct { // of type string or *string. The argument names for the resulting prompt // definition correspond to the JSON names of the request fields, and any // fields that are not marked "omitempty" are considered required. 
-func MakePrompt[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) (*protocol.GetPromptResult, error), opts ...PromptOption) *Prompt { +func NewPrompt[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) (*protocol.GetPromptResult, error), opts ...PromptOption) *Prompt { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 912125629b6..9d18d348597 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -13,26 +13,26 @@ import ( "golang.org/x/tools/internal/mcp/protocol" ) -// testPromptHandler is used for type inference in TestMakePrompt. +// testPromptHandler is used for type inference in TestNewPrompt. func testPromptHandler[T any](context.Context, *mcp.ClientConnection, T) (*protocol.GetPromptResult, error) { panic("not implemented") } -func TestMakePrompt(t *testing.T) { +func TestNewPrompt(t *testing.T) { tests := []struct { prompt *mcp.Prompt want []protocol.PromptArgument }{ { - mcp.MakePrompt("empty", "", testPromptHandler[struct{}]), + mcp.NewPrompt("empty", "", testPromptHandler[struct{}]), nil, }, { - mcp.MakePrompt("add_arg", "", testPromptHandler[struct{}], mcp.Argument("x")), + mcp.NewPrompt("add_arg", "", testPromptHandler[struct{}], mcp.Argument("x")), []protocol.PromptArgument{{Name: "x"}}, }, { - mcp.MakePrompt("combo", "", testPromptHandler[struct { + mcp.NewPrompt("combo", "", testPromptHandler[struct { Name string `json:"name"` Country string `json:"country,omitempty"` State string @@ -48,7 +48,7 @@ func TestMakePrompt(t *testing.T) { } for _, test := range tests { if diff := cmp.Diff(test.want, test.prompt.Definition.Arguments); diff != "" { - t.Errorf("MakePrompt(%v) mismatch (-want +got):\n%s", test.prompt.Definition.Name, diff) + t.Errorf("NewPrompt(%v) mismatch (-want +got):\n%s", test.prompt.Definition.Name, diff) } } } diff --git 
a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 840c09eccdf..3a50877303e 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -27,7 +27,7 @@ func ExampleServer() { clientTransport, serverTransport := mcp.NewLocalTransport() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) + server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) clientConnection, err := server.Connect(ctx, serverTransport, nil) if err != nil { diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 9bf3be50fa5..49673a13bd1 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -26,7 +26,7 @@ func Add(ctx context.Context, cc *mcp.ClientConnection, params *AddParams) ([]mc func ExampleSSEHandler() { server := mcp.NewServer("adder", "v0.0.1", nil) - server.AddTools(mcp.MakeTool("add", "add two numbers", Add)) + server.AddTools(mcp.NewTool("add", "add two numbers", Add)) handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) httpServer := httptest.NewServer(handler) diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 99f9a294b02..801006f4a7b 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -20,7 +20,7 @@ func TestSSEServer(t *testing.T) { t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { ctx := context.Background() server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(MakeTool("greet", "say hi", sayHi)) + server.AddTools(NewTool("greet", "say hi", sayHi)) sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 6428a188f46..9b2ceeeda8a 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -23,7 +23,7 @@ type Tool struct { Handler ToolHandler } -// MakeTool is a helper to make a tool using reflection on the 
given handler. +// NewTool is a helper to make a tool using reflection on the given handler. // // If provided, variadic [ToolOption] values may be used to customize the tool. // @@ -32,11 +32,11 @@ type Tool struct { // schema may be customized using the [Input] option. // // The handler request type must translate to a valid schema, as documented by -// [jsonschema.ForType]; otherwise, MakeTool panics. +// [jsonschema.ForType]; otherwise, NewTool panics. // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func MakeTool[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) ([]Content, error), opts ...ToolOption) *Tool { +func NewTool[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) ([]Content, error), opts ...ToolOption) *Tool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index 85b5e55e931..b3a2cc025d4 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -14,18 +14,18 @@ import ( "golang.org/x/tools/internal/mcp/jsonschema" ) -// testToolHandler is used for type inference in TestMakeTool. +// testToolHandler is used for type inference in TestNewTool. 
func testToolHandler[T any](context.Context, *mcp.ClientConnection, T) ([]mcp.Content, error) { panic("not implemented") } -func TestMakeTool(t *testing.T) { +func TestNewTool(t *testing.T) { tests := []struct { tool *mcp.Tool want *jsonschema.Schema }{ { - mcp.MakeTool("basic", "", testToolHandler[struct { + mcp.NewTool("basic", "", testToolHandler[struct { Name string `json:"name"` }]), &jsonschema.Schema{ @@ -38,7 +38,7 @@ func TestMakeTool(t *testing.T) { }, }, { - mcp.MakeTool("enum", "", testToolHandler[struct{ Name string }], mcp.Input( + mcp.NewTool("enum", "", testToolHandler[struct{ Name string }], mcp.Input( mcp.Property("Name", mcp.Enum("x", "y", "z")), )), &jsonschema.Schema{ @@ -51,7 +51,7 @@ func TestMakeTool(t *testing.T) { }, }, { - mcp.MakeTool("required", "", testToolHandler[struct { + mcp.NewTool("required", "", testToolHandler[struct { Name string `json:"name"` Language string `json:"language"` X int `json:"x,omitempty"` @@ -71,7 +71,7 @@ func TestMakeTool(t *testing.T) { }, }, { - mcp.MakeTool("set_schema", "", testToolHandler[struct { + mcp.NewTool("set_schema", "", testToolHandler[struct { X int `json:"x,omitempty"` Y int `json:"y,omitempty"` }], mcp.Input( @@ -84,7 +84,7 @@ func TestMakeTool(t *testing.T) { } for _, test := range tests { if diff := cmp.Diff(test.want, test.tool.Definition.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("MakeTool(%v) mismatch (-want +got):\n%s", test.tool.Definition.Name, diff) + t.Errorf("NewTool(%v) mismatch (-want +got):\n%s", test.tool.Definition.Name, diff) } } } From 3f0db3480a6a5414c32f0641a82574cd1a7d4efa Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 7 May 2025 16:43:03 +0000 Subject: [PATCH 015/196] internal/mcp: rename ClientConnection to ServerConnection It's too confusing that you send server requests through the Client, and send client requests through the ClientConnection. 
Rename ClientConnection to ServerConnection, so that the types on the server side of the control diagram start with 'Server'. Change-Id: Idfd27af464a9296a2341cd23b2075cd9d0eb6b45 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670538 LUCI-TryBot-Result: Go LUCI Auto-Submit: Robert Findley Reviewed-by: Jonathan Amsterdam --- internal/mcp/README.md | 3 +- internal/mcp/client.go | 4 +- internal/mcp/examples/hello/main.go | 4 +- internal/mcp/examples/sse/main.go | 2 +- internal/mcp/mcp.go | 4 +- internal/mcp/mcp_test.go | 12 +++--- internal/mcp/prompt.go | 6 +-- internal/mcp/prompt_test.go | 2 +- internal/mcp/server.go | 60 +++++++++++++++-------------- internal/mcp/server_example_test.go | 2 +- internal/mcp/sse.go | 8 ++-- internal/mcp/sse_example_test.go | 2 +- internal/mcp/sse_test.go | 8 ++-- internal/mcp/tool.go | 6 +-- internal/mcp/tool_test.go | 2 +- internal/mcp/transport.go | 3 +- 16 files changed, 65 insertions(+), 63 deletions(-) diff --git a/internal/mcp/README.md b/internal/mcp/README.md index 8ce9277bec9..34e53d86230 100644 --- a/internal/mcp/README.md +++ b/internal/mcp/README.md @@ -59,7 +59,7 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) ([]mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: "Hi " + params.Name}, }, nil @@ -110,6 +110,7 @@ server.AddPrompts(mcp.MakePrompt("greet", "", PromptHi)) Resources are a core primitive in the Model Context Protocol (MCP) that allow servers to expose data and content that can be read by clients and used as context for LLM interactions. + Resources are not supported yet. 
## Testing diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 214b516d00b..8af1c7ea8f6 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -49,7 +49,7 @@ type ClientOptions struct { ConnectionOptions } -// bind implements the binder[*ServerConnection] interface, so that Clients can +// bind implements the binder[*Client] interface, so that Clients can // be connected using [connect]. func (c *Client) bind(conn *jsonrpc2.Connection) *Client { c.mu.Lock() @@ -58,7 +58,7 @@ func (c *Client) bind(conn *jsonrpc2.Connection) *Client { return c } -// disconnect implements the binder[*ServerConnection] interface, so that +// disconnect implements the binder[*Client] interface, so that // Clients can be connected using [connect]. func (c *Client) disconnect(*Client) { // Do nothing. In particular, do not set conn to nil: it needs to exist so it can diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index d85317705d3..56a32618e7b 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -21,13 +21,13 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) ([]mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: "Hi " + params.Name}, }, nil } -func PromptHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) (*protocol.GetPromptResult, error) { +func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) (*protocol.GetPromptResult, error) { return &protocol.GetPromptResult{ Description: "Code review prompt", Messages: []protocol.PromptMessage{ diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index 14c51b910c1..fc590f7e0eb 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -19,7 +19,7 
@@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *SayHiParams) ([]mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: "Hi " + params.Name}, }, nil diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index fa35b92e0dc..a20b11def3f 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -10,7 +10,7 @@ // // Client Server // ⇅ (jsonrpc2) ⇅ -// Client Transport ⇄ Server Transport ⇄ ClientConnection +// Client Transport ⇄ Server Transport ⇄ ServerConnection // // A [Client] is an MCP client, which can be configured with various client // capabilities. Clients may be connected to a [Server] instance @@ -19,7 +19,7 @@ // Similarly, a [Server] is an MCP server, which can be configured with various // server capabilities. Servers may be connected to one or more [Client] // instances using the [Server.Connect] method, which creates a -// [ClientConnection]. +// [ServerConnection]. // // A [Transport] connects a bidirectional [Stream] of jsonrpc2 messages. In // practice, transports in the MCP spec are are either client transports or diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index a91d50a3705..2797519c28c 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -24,7 +24,7 @@ type hiParams struct { Name string } -func sayHi(ctx context.Context, cc *ClientConnection, v hiParams) ([]Content, error) { +func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]Content, error) { if err := cc.Ping(ctx); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } @@ -43,13 +43,13 @@ func TestEndToEnd(t *testing.T) { // The 'fail' tool returns this error. 
failure := errors.New("mcp failure") s.AddTools( - NewTool("fail", "just fail", func(context.Context, *ClientConnection, struct{}) ([]Content, error) { + NewTool("fail", "just fail", func(context.Context, *ServerConnection, struct{}) ([]Content, error) { return nil, failure }), ) s.AddPrompts( - NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ClientConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) { + NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) { return &protocol.GetPromptResult{ Description: "Code review prompt", Messages: []protocol.PromptMessage{ @@ -57,7 +57,7 @@ func TestEndToEnd(t *testing.T) { }, }, nil }), - NewPrompt("fail", "", func(_ context.Context, _ *ClientConnection, params struct{}) (*protocol.GetPromptResult, error) { + NewPrompt("fail", "", func(_ context.Context, _ *ServerConnection, params struct{}) (*protocol.GetPromptResult, error) { return nil, failure }), ) @@ -195,7 +195,7 @@ func TestEndToEnd(t *testing.T) { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. 
-func basicConnection(t *testing.T, tools ...*Tool) (*ClientConnection, *Client) { +func basicConnection(t *testing.T, tools ...*Tool) (*ServerConnection, *Client) { t.Helper() ctx := context.Background() @@ -283,7 +283,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, cc *ClientConnection, v struct{}) ([]Content, error) { + slowRequest := func(ctx context.Context, cc *ServerConnection, v struct{}) ([]Content, error) { start <- struct{}{} select { case <-ctx.Done(): diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 3b816048d13..878a54eed7c 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -17,7 +17,7 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ClientConnection, map[string]string) (*protocol.GetPromptResult, error) +type PromptHandler func(context.Context, *ServerConnection, map[string]string) (*protocol.GetPromptResult, error) // A Prompt is a prompt definition bound to a prompt handler. type Prompt struct { @@ -33,7 +33,7 @@ type Prompt struct { // of type string or *string. The argument names for the resulting prompt // definition correspond to the JSON names of the request fields, and any // fields that are not marked "omitempty" are considered required. 
-func NewPrompt[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) (*protocol.GetPromptResult, error), opts ...PromptOption) *Prompt { +func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) (*protocol.GetPromptResult, error), opts ...PromptOption) *Prompt { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) @@ -61,7 +61,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, Required: required[name], }) } - prompt.Handler = func(ctx context.Context, cc *ClientConnection, args map[string]string) (*protocol.GetPromptResult, error) { + prompt.Handler = func(ctx context.Context, cc *ServerConnection, args map[string]string) (*protocol.GetPromptResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. rawArgs, err := json.Marshal(args) diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 9d18d348597..22afe812305 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -14,7 +14,7 @@ import ( ) // testPromptHandler is used for type inference in TestNewPrompt. -func testPromptHandler[T any](context.Context, *mcp.ClientConnection, T) (*protocol.GetPromptResult, error) { +func testPromptHandler[T any](context.Context, *mcp.ServerConnection, T) (*protocol.GetPromptResult, error) { panic("not implemented") } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 459fd3ce721..4f9e196e03f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -28,7 +28,7 @@ type Server struct { mu sync.Mutex prompts []*Prompt tools []*Tool - clients []*ClientConnection + conns []*ServerConnection } // ServerOptions is used to configure behavior of the server. @@ -74,14 +74,14 @@ func (s *Server) AddTools(tools ...*Tool) { // Clients returns an iterator that yields the current set of client // connections. 
-func (s *Server) Clients() iter.Seq[*ClientConnection] { +func (s *Server) Clients() iter.Seq[*ServerConnection] { s.mu.Lock() - clients := slices.Clone(s.clients) + clients := slices.Clone(s.conns) s.mu.Unlock() return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, _ *ClientConnection, params *protocol.ListPromptsParams) (*protocol.ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *protocol.ListPromptsParams) (*protocol.ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() @@ -92,7 +92,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ClientConnection, params *pro return res, nil } -func (s *Server) getPrompt(ctx context.Context, cc *ClientConnection, params *protocol.GetPromptParams) (*protocol.GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *protocol.GetPromptParams) (*protocol.GetPromptResult, error) { s.mu.Lock() var prompt *Prompt if i := slices.IndexFunc(s.prompts, func(t *Prompt) bool { @@ -108,7 +108,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ClientConnection, params *pr return prompt.Handler(ctx, cc, params.Arguments) } -func (s *Server) listTools(_ context.Context, _ *ClientConnection, params *protocol.ListToolsParams) (*protocol.ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *protocol.ListToolsParams) (*protocol.ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() @@ -119,7 +119,7 @@ func (s *Server) listTools(_ context.Context, _ *ClientConnection, params *proto return res, nil } -func (s *Server) callTool(ctx context.Context, cc *ClientConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) { s.mu.Lock() var tool *Tool if i := slices.IndexFunc(s.tools, func(t *Tool) bool { 
@@ -146,22 +146,22 @@ func (s *Server) Run(ctx context.Context, t Transport, opts *ConnectionOptions) return cc.Wait() } -// bind implements the binder[*ClientConnection] interface, so that Servers can +// bind implements the binder[*ServerConnection] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(conn *jsonrpc2.Connection) *ClientConnection { - cc := &ClientConnection{conn: conn, server: s} +func (s *Server) bind(conn *jsonrpc2.Connection) *ServerConnection { + cc := &ServerConnection{conn: conn, server: s} s.mu.Lock() - s.clients = append(s.clients, cc) + s.conns = append(s.conns, cc) s.mu.Unlock() return cc } -// disconnect implements the binder[*ClientConnection] interface, so that +// disconnect implements the binder[*ServerConnection] interface, so that // Servers can be connected using [connect]. -func (s *Server) disconnect(cc *ClientConnection) { +func (s *Server) disconnect(cc *ServerConnection) { s.mu.Lock() defer s.mu.Unlock() - s.clients = slices.DeleteFunc(s.clients, func(cc2 *ClientConnection) bool { + s.conns = slices.DeleteFunc(s.conns, func(cc2 *ServerConnection) bool { return cc2 == cc }) } @@ -172,15 +172,17 @@ func (s *Server) disconnect(cc *ClientConnection) { // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (*ClientConnection, error) { +func (s *Server) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (*ServerConnection, error) { return connect(ctx, t, opts, s) } -// A ClientConnection is a connection with an MCP client. +// A ServerConnection is a connection from a single MCP client. Its methods can +// be used to send requests or notifications to the client. Create a connection +// by calling [Server.Connect]. 
// -// It handles messages from the client, and can be used to send messages to the -// client. Create a connection by calling [Server.Connect]. -type ClientConnection struct { +// Call [ServerConnection.Close] to close the connection, or await client +// termination with [ServerConnection.Wait]. +type ServerConnection struct { server *Server conn *jsonrpc2.Connection @@ -190,11 +192,11 @@ type ClientConnection struct { } // Ping makes an MCP "ping" request to the client. -func (cc *ClientConnection) Ping(ctx context.Context) error { +func (cc *ServerConnection) Ping(ctx context.Context) error { return call(ctx, cc.conn, "ping", nil, nil) } -func (cc *ClientConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { +func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { cc.mu.Lock() initialized := cc.initialized cc.mu.Unlock() @@ -210,7 +212,7 @@ func (cc *ClientConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( } } - // TODO: embed the incoming request ID in the ClientContext (or, more likely, + // TODO: embed the incoming request ID in the client context (or, more likely, // a wrapper around it), so that we can correlate responses and notifications // to the handler; this is required for the new session-based transport. 
@@ -239,7 +241,7 @@ func (cc *ClientConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( return nil, jsonrpc2.ErrNotHandled } -func (cc *ClientConnection) initialize(ctx context.Context, _ *ClientConnection, params *protocol.InitializeParams) (*protocol.InitializeResult, error) { +func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, params *protocol.InitializeParams) (*protocol.InitializeResult, error) { cc.mu.Lock() cc.initializeParams = params cc.mu.Unlock() @@ -274,15 +276,15 @@ func (cc *ClientConnection) initialize(ctx context.Context, _ *ClientConnection, }, nil } -// Close performs a graceful close of the connection, preventing new requests -// from being handled, and waiting for ongoing requests to return. Close then -// terminates the connection. -func (cc *ClientConnection) Close() error { +// Close performs a graceful shutdown of the connection, preventing new +// requests from being handled, and waiting for ongoing requests to return. +// Close then terminates the connection. +func (cc *ServerConnection) Close() error { return cc.conn.Close() } // Wait waits for the connection to be closed by the client. -func (cc *ClientConnection) Wait() error { +func (cc *ServerConnection) Wait() error { return cc.conn.Wait() } @@ -290,7 +292,7 @@ func (cc *ClientConnection) Wait() error { // // Importantly, it returns nil if the handler returned an error, which is a // requirement of the jsonrpc2 package. 
-func dispatch[TConn, TParams, TResult any](ctx context.Context, conn TConn, req *jsonrpc2.Request, f func(context.Context, TConn, TParams) (TResult, error)) (any, error) { +func dispatch[TParams, TResult any](ctx context.Context, conn *ServerConnection, req *jsonrpc2.Request, f func(context.Context, *ServerConnection, TParams) (TResult, error)) (any, error) { var params TParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 3a50877303e..9cdc1f2ad9f 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -16,7 +16,7 @@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *SayHiParams) ([]mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: "Hi " + params.Name}, }, nil diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index bdc62a71cd3..511d4bca70f 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -63,8 +63,8 @@ func writeEvent(w io.Writer, evt event) (int, error) { // // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { - getServer func(request *http.Request) *Server - onClient func(*ClientConnection) // for testing; must not block + getServer func(request *http.Request) *Server + onConnection func(*ServerConnection) // for testing; must not block mu sync.Mutex sessions map[string]*sseSession @@ -177,8 +177,8 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, "connection failed", http.StatusInternalServerError) return } - if h.onClient != nil { - h.onClient(cc) + if h.onConnection != nil { + h.onConnection(cc) } defer cc.Close() // close the transport when the GET exits diff --git 
a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 49673a13bd1..028084faf66 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -18,7 +18,7 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, cc *mcp.ClientConnection, params *AddParams) ([]mcp.Content, error) { +func Add(ctx context.Context, cc *mcp.ServerConnection, params *AddParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: fmt.Sprintf("%d", params.X+params.Y)}, }, nil diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 801006f4a7b..f901964a51f 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -24,10 +24,10 @@ func TestSSEServer(t *testing.T) { sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) - clients := make(chan *ClientConnection, 1) - sseHandler.onClient = func(cc *ClientConnection) { + conns := make(chan *ServerConnection, 1) + sseHandler.onConnection = func(cc *ServerConnection) { select { - case clients <- cc: + case conns <- cc: default: } } @@ -43,7 +43,7 @@ func TestSSEServer(t *testing.T) { if err := c.Ping(ctx); err != nil { t.Fatal(err) } - cc := <-clients + cc := <-conns gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}) if err != nil { t.Fatal(err) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 9b2ceeeda8a..4c0f0eafe00 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -15,7 +15,7 @@ import ( ) // A ToolHandler handles a call to tools/call. -type ToolHandler func(context.Context, *ClientConnection, map[string]json.RawMessage) (*protocol.CallToolResult, error) +type ToolHandler func(context.Context, *ServerConnection, map[string]json.RawMessage) (*protocol.CallToolResult, error) // A Tool is a tool definition that is bound to a tool handler. 
type Tool struct { @@ -36,12 +36,12 @@ type Tool struct { // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func NewTool[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) ([]Content, error), opts ...ToolOption) *Tool { +func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]Content, error), opts ...ToolOption) *Tool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) } - wrapped := func(ctx context.Context, cc *ClientConnection, args map[string]json.RawMessage) (*protocol.CallToolResult, error) { + wrapped := func(ctx context.Context, cc *ServerConnection, args map[string]json.RawMessage) (*protocol.CallToolResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. rawArgs, err := json.Marshal(args) diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index b3a2cc025d4..45bc82048e1 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -15,7 +15,7 @@ import ( ) // testToolHandler is used for type inference in TestNewTool. -func testToolHandler[T any](context.Context, *mcp.ClientConnection, T) ([]mcp.Content, error) { +func testToolHandler[T any](context.Context, *mcp.ServerConnection, T) ([]mcp.Content, error) { panic("not implemented") } diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 1f19aa27b9d..1edbc6521b6 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -73,8 +73,7 @@ func NewLocalTransport() (*IOTransport, *IOTransport) { return &IOTransport{c1}, &IOTransport{c2} } -// handler is an unexported version of jsonrpc2.Handler, to be implemented by -// [ServerConnection] and [ClientConnection]. +// handler is an unexported version of jsonrpc2.Handler. 
type handler interface { handle(ctx context.Context, req *jsonrpc2.Request) (result any, err error) } From 4a7262515ace698d9b71e55a50024457da173a7e Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 6 May 2025 19:58:27 +0000 Subject: [PATCH 016/196] internal/mcp/design: start a full design doc, with stubs Start writing up a formal design doc for an MCP SDK based on what we've learned from the prototype. Include discussion of package layout, jsonrpc2, and transport. Lay out the rest of the design sections, leaving stub comments describing what they should cover. Change-Id: Id0e5b408eac8c0acb5ef90a32d0d69473d8e83f0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670536 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Auto-Submit: Robert Findley --- internal/mcp/design/design.md | 282 ++++++++++++++++++++++++++++++++++ internal/mcp/transport.go | 3 +- 2 files changed, 283 insertions(+), 2 deletions(-) create mode 100644 internal/mcp/design/design.md diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md new file mode 100644 index 00000000000..31ed720a043 --- /dev/null +++ b/internal/mcp/design/design.md @@ -0,0 +1,282 @@ +# Go MCP SDK design + +This file discusses the design of a Go SDK for the [model context +protocol](https://modelcontextprotocol.io/specification/2025-03-26). It is +intended to seed a GitHub discussion about the official Go MCP SDK, and so +approaches each design aspect from first principles. Of course, there is +significant prior art in various unofficial SDKs, along with other practical +constraints. Nevertheless, if we can first agree on the most natural way to +model the MCP spec, we can then discuss the shortest path to get there. + +# Requirements + +These may be obvious, but it's worthwhile to define goals for an official MCP +SDK. 
An official SDK should aim to be: + +- **complete**: it should be possible to implement every feature of the MCP + spec, and these features should conform to all of the semantics described by + the spec. +- **idiomatic**: as much as possible, MCP features should be modeled using + features of the Go language and its standard library. Additionally, the SDK + should repeat idioms from similar domains. +- **robust**: the SDK itself should be well tested and reliable, and should + enable easy testability for its users. +- **future-proof**: the SDK should allow for future evolution of the MCP spec, + in such a way that we can (as much as possible) avoid incompatible changes to + the SDK API. +- **extensible**: to best serve the previous four concerns, the SDK should be + minimal. However, it should admit extensibility using (for example) simple + interfaces, middleware, or hooks. + +# Design considerations + +In the sections below, we visit each aspect of the MCP spec, in approximately +the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26) +For each, we discuss considerations for the Go implementation. In many cases an +API is suggested, though in some there many be open questions. + + + +## Foundations + +### Package layout + +In the sections that follow, it is assumed that most of the MCP API lives in a +single shared package, the `mcp` package. This is inconsistent with other MCP +SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or +`google.golang.org/grpc`. + +Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) +belongs in a separate package. + +### jsonrpc2 and Transports + +The MCP is defined in terms of client-server communication over bidirectional +JSON-RPC message streams. Specifically, version `2025-03-26` of the spec +defines two transports: + +- **stdio**: communication with a subprocess over stdin/stdout. 
+- **streamable http**: communication over a relatively complicated series of + text/event-stream GET and HTTP POST requests. + +Additionally, version `2024-11-05` of the spec defined a simpler HTTP transport: + +- **sse**: client issues a hanging GET request and receives messages via + `text/event-stream`, and sends messages via POST to a session endpoint. + +Furthermore, the spec [states](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#custom-transports) that it must be possible for users to define their own custom transports. + +Given the diversity of the transport implementations, they can be challenging +to abstract. However, since JSON-RPC requires a bidirectional stream, we can +use this to model the MCP transport abstraction: + +```go +// A Transport is used to create a bidirectional connection between MCP client +// and server. +type Transport interface { + Connect(ctx context.Context) (Stream, error) +} + +// A Stream is a bidirectional jsonrpc2 Stream. +type Stream interface { + Read(ctx context.Context) (jsonrpc2.Message, error) + Write(ctx context.Context, jsonrpc2.Message) error + Close() error +} +``` + +Specifically, a `Transport` is something that connects a logical JSON-RPC +stream, and nothing more (methods accept a Go `Context` and return an `error`, +as is idiomatic for APIs that do I/O). Streams must be closeable in order to +implement client and server shutdown, and therefore conform to the `io.Closer` +interface. + +Other SDKs define higher-level transports, with, for example, methods to send a +notification or make a call. Those are jsonrpc2 operations on top of the +logical stream, and the lower-level interface is easier to implement in most +cases, which means it is easier to implement custom transports or middleware. + +For our prototype, we've used an internal `jsonrpc2` package based on the Go +language server `gopls`, which we propose to fork for the MCP SDK. 
It already +handles concerns like client/server connection, request lifecycle, +cancellation, and shutdown. + +In the MCP Spec, the **stdio** transport uses newline-delimited JSON to +communicate over stdin/stdout. It's possible to model both client side and +server side of this communication with a shared type that communicates over an +`io.ReadWriteCloser`. However, for the purposes of future-proofing, we should +use a distinct types for both client and server stdio transport. + +The `CommandTransport` is the client side of the stdio transport, and +connects by starting a command and binding its jsonrpc2 stream to its +stdin/stdout. + +```go +// A CommandTransport is a [Transport] that runs a command and communicates +// with it over stdin/stdout, using newline-delimited JSON. +type CommandTransport struct { /* unexported fields */ } + +// NewCommandTransport returns a [CommandTransport] that runs the given command +// and communicates with it over stdin/stdout. +func NewCommandTransport(cmd *exec.Command) *CommandTransport + +// Connect starts the command, and connects to it over stdin/stdout. +func (t *CommandTransport) Connect(ctx context.Context) (Stream, error) { +``` + +The `StdIOTransport` is the server side of the stdio transport, and connects by +binding to `os.Stdin` and `os.Stdout`. + +```go +// A StdIOTransport is a [Transport] that communicates using newline-delimited +// JSON over stdin/stdout. +type StdIOTransport struct { /* unexported fields */ } + +func NewStdIOTransport() *StdIOTransport { + +func (t *StdIOTransport) Connect(context.Context) (Stream, error) +``` + +The HTTP transport APIs are even more asymmetrical. Since connections are initiated +via HTTP requests, the client developer will create a transport, but +the server developer will typically install an HTTP handler. Internally, the +HTTP handler will create a transport for each new client connection. 
+ +Importantly, since they serve many connections, the HTTP handlers must accept a +callback to get an MCP server for each new session. + +```go +// SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by +// the 2024-11-05 version of the MCP protocol. +type SSEHandler struct { /* unexported fields */ } + +// NewSSEHandler returns a new [SSEHandler] that is ready to serve HTTP. +// +// The getServer function is used to bind created servers for new sessions. It +// is OK for getServer to return the same server multiple times. +func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler + +func (*SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) + +// Close prevents the SSEHandler from accepting new sessions, closes active +// sessions, and awaits their graceful termination. +func (*SSEHandler) Close() error +``` + +Notably absent are options to hook into the request handling for the purposes +of authentication or context injection. These concerns are better handled using +standard HTTP middleware patterns. + + + + + +### Protocol types + + + +### Clients and Servers + + + +### Errors + + + +### Cancellation + + + +### Progress handling + + + +### Ping / Keepalive + + + +## Client Features + +### Roots + + + +### Sampling + + + +## Server Features + +### Tools + + + +#### JSON Schema + + + +### Prompts + + + +### Resources + + + +### Completion + + + +### Logging + + + +### Pagination + + + +## Compatibility with existing SDKs + + diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 1edbc6521b6..0b1ed3a2d28 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -35,8 +35,7 @@ type Transport interface { Connect(ctx context.Context) (Stream, error) } -// A Stream is an abstract bidirectional jsonrpc2 Stream. -// It is used by [connect] to establish a [jsonrpc2.Connection]. +// A Stream is a bidirectional jsonrpc2 Stream. 
type Stream interface { jsonrpc2.Reader jsonrpc2.Writer From dbc82b6d796309612826fb144bf18704209ca24a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 7 May 2025 09:27:59 -0400 Subject: [PATCH 017/196] internal/mcp/jsonschema: make Schema.every fully recursive s.every now visits the entire subtree of schemas at s, not only the immediate children of s. This re-introduces a tower: now of closures instead of iterators. We could remove it having everyChild call every recursively instead of returning, just passing f down all the way unwrapped. But this is going to complicate the recursive function inside resolveURIs, which needs to pass parent context down. It would have the same problem that ast.Inspect has, and we'd have to adopt the same solution, maintaining a stack explicitly. I feel it's cleaner to keep the recursion simple at the cost of building several closures. To do otherwise would be premature optimization. Change-Id: Ic453cdaac482da8147dafc34b0117d25458d1c9d Reviewed-on: https://go-review.googlesource.com/c/tools/+/670675 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- internal/mcp/jsonschema/schema.go | 2 +- internal/mcp/jsonschema/schema_test.go | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 5bb6de6eb56..ade91e1ef31 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -293,7 +293,7 @@ func Ptr[T any](x T) *T { return &x } // It stops when f returns false. func (s *Schema) every(f func(*Schema) bool) bool { return s == nil || - f(s) && s.everyChild(f) + f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) } // everyChild reports whether f is true for every immediate child schema of s. 
diff --git a/internal/mcp/jsonschema/schema_test.go b/internal/mcp/jsonschema/schema_test.go index 2bda7818af2..4d042d560b6 100644 --- a/internal/mcp/jsonschema/schema_test.go +++ b/internal/mcp/jsonschema/schema_test.go @@ -111,3 +111,18 @@ func TestUnmarshalErrors(t *testing.T) { } } + +func TestEvery(t *testing.T) { + // Schema.every should visit all descendants of a schema, not just the immediate ones. + s := &Schema{ + Items: &Schema{ + Items: &Schema{}, + }, + } + want := 3 + got := 0 + s.every(func(*Schema) bool { got++; return true }) + if got != want { + t.Errorf("got %d, want %d", got, want) + } +} From 89c7c2cd971df18650b29c17c3bd52bc607f3934 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 8 May 2025 07:29:11 -0400 Subject: [PATCH 018/196] internal/mcp/design: logging Design for logging. Change-Id: I787bb51318337bd09f973e04d0e4dbf526e09fb6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670935 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Auto-Submit: Jonathan Amsterdam --- internal/mcp/design/design.md | 56 ++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 31ed720a043..6b6a29ebb59 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -62,7 +62,7 @@ The MCP is defined in terms of client-server communication over bidirectional JSON-RPC message streams. Specifically, version `2025-03-26` of the spec defines two transports: -- **stdio**: communication with a subprocess over stdin/stdout. +- ****: communication with a subprocess over stdin/stdout. - **streamable http**: communication over a relatively complicated series of text/event-stream GET and HTTP POST requests. @@ -271,12 +271,60 @@ Server.RemoveResources ### Logging - +Servers have access to a `slog.Logger` that writes to the client. 
A call to +a log method like `Info`is translated to a `LoggingMessageNotification` as +follows: + +- An attribute with key "logger" is used to populate the "logger" field of the notification. + +- The remaining attributes and the message populate the "data" field with the + output of a `slog.JSONHandler`: The result is always a JSON object, with the + key "msg" for the message. + +- The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the + corresponding levels in the MCP spec. The other spec levels will be mapped + to integers between the slog levels. For example, "notice" is level 2 because + it is between "warning" (slog value 4) and "info" (slog value 0). + The `mcp` package defines consts for these levels. To log at the "notice" + level, a server would call `Log(ctx, mcp.LevelNotice, "message")`. ### Pagination -## Compatibility with existing SDKs +## Differences with mcp-go + +The most popular MCP package for Go is [mcp-go](https://pkg.go.dev/github.com/ +mark3labs/mcp-go). While we admire the thoughfulness of its design and the high +quality of its implementation, we made different choices. Although the APIs are +not compatible, translating between them is straightforward. (Later, we will +provide a detailed translation guide.) + +## Packages + +As we mentioned above, we decided to put most of the API into a single package. +The exceptions are the JSON-RPC layer, the JSON Schema implementation, and the +parts of the MCP protocol that users don't need. The resulting `mcp` includes +all the functionality of mcp-go's `mcp`, `client`, `server` and `transport` +packages, but is smaller than the `mcp` package alone. + +## Hooks + +Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field +in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. +As described above, these can be replaced by middleware. We +don't define any middleware types at present, but will do so if there is demand. 
+(We're minimalists, not savages.) + +## Servers + +In mcp-go, server authors create an `MCPServer`, populate it with tools, +resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. These +also use session IDs, which are exposed. Users can manage their own sessions +with `RegisterSession` and `UnregisterSession`. - +We find the similarity in names among the three server types to be confusing, +and we could not discover any uses of the session methods in the open-source +ecosystem. In our design is similar, server authors create a `Server`, and then +connect it to a `Transport` or SSE handler. We manage multiple web clients for a +single server using session IDs internally, but do not expose them. From bad5619656c748f03c01a9dfe2c9dd4e5f96ba7c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 8 May 2025 07:09:48 -0400 Subject: [PATCH 019/196] internal/mcp/protocol: spell "Id" idiomatically Enhance the generator so that it recognizes initialisms like "Id" and writes them in all caps, as is idiomatic for Go. Also, fix the import path of jsonschema in the generated code. 
Change-Id: Ib12e9fbd543d360035d5e324f0aefca584ac7c85 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670916 Auto-Submit: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/protocol/generate.go | 40 ++++++++++++++++++++++++++++--- internal/mcp/protocol/protocol.go | 2 +- internal/mcp/transport.go | 4 ++-- 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/internal/mcp/protocol/generate.go b/internal/mcp/protocol/generate.go index e8c92bc802e..30bc81d978b 100644 --- a/internal/mcp/protocol/generate.go +++ b/internal/mcp/protocol/generate.go @@ -22,11 +22,12 @@ import ( "net/http" "os" "reflect" + "regexp" "slices" "strings" - "golang.org/x/tools/internal/mcp/internal/jsonschema" "golang.org/x/tools/internal/mcp/internal/util" + "golang.org/x/tools/internal/mcp/jsonschema" ) var schemaFile = flag.String("schema_file", "", "if set, use this file as the persistent schema file") @@ -138,7 +139,7 @@ package protocol import ( "encoding/json" - "golang.org/x/tools/internal/mcp/internal/jsonschema" + "golang.org/x/tools/internal/mcp/jsonschema" ) `) @@ -381,11 +382,44 @@ func canHaveAdditionalProperties(s *jsonschema.Schema) bool { // exportName returns an exported name for a Go symbol, based on the given name // in the JSON schema, removing leading underscores and capitalizing. +// It also rewrites initialisms. func exportName(s string) string { if strings.HasPrefix(s, "_") { s = s[1:] } - return strings.ToUpper(s[:1]) + s[1:] + s = strings.ToUpper(s[:1]) + s[1:] + // Replace an initialism if it is its own "word": see the init function below for + // a definition. + // There is probably a clever way to write this whole thing with one regexp and + // a Replace method, but it would be quite obscure. + // This doesn't have to be fast, because the first match will rarely succeed. + for ism, re := range initialisms { + replacement := strings.ToUpper(ism) + // Find the index of one match at a time, and replace. 
(We can't find all + // at once, because the replacement will change the indices.) + for { + if loc := re.FindStringIndex(s); loc != nil { + s = s[:loc[0]] + replacement + s[loc[1]:] + } else { + break + } + } + } + return s +} + +// Map from initialism to the regexp that matches it. +var initialisms = map[string]*regexp.Regexp{ + "Id": nil, + "Url": nil, + "Uri": nil, +} + +func init() { + for ism := range initialisms { + // Match ism if it is at the end, or followed by an uppercase letter or a number. + initialisms[ism] = regexp.MustCompile(ism + `($|[A-Z0-9])`) + } } func assert(cond bool, msg string) { diff --git a/internal/mcp/protocol/protocol.go b/internal/mcp/protocol/protocol.go index ead91c5b2b1..399408238a4 100644 --- a/internal/mcp/protocol/protocol.go +++ b/internal/mcp/protocol/protocol.go @@ -62,7 +62,7 @@ type CancelledParams struct { // // This MUST correspond to the ID of a request previously issued in the same // direction. - RequestId any `json:"requestId"` + RequestID any `json:"requestId"` } // Capabilities a client may support. Known capabilities are defined here, in diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 0b1ed3a2d28..403d3a2371c 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -135,7 +135,7 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err } - id, err := jsonrpc2.MakeID(params.RequestId) + id, err := jsonrpc2.MakeID(params.RequestID) if err != nil { return nil, err } @@ -156,7 +156,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, // Notify the peer of cancellation. 
err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &protocol.CancelledParams{ Reason: ctx.Err().Error(), - RequestId: call.ID().Raw(), + RequestID: call.ID().Raw(), }) return errors.Join(ctx.Err(), err) case err != nil: From 51dcb32d9b958687a0361730dd5aa3c291419f00 Mon Sep 17 00:00:00 2001 From: cuishuang Date: Thu, 1 May 2025 22:43:53 +0800 Subject: [PATCH 020/196] gopls/internal/analysis/modernize: add checks to prevent invalid fixes from slicesContains When using `slices.ContainsFunc` to refactor loops that iterate over a slice of concrete types with a callback function that accepts an interface parameter, Go's type system fails to infer the correct type parameters. This results in compilation errors like `in call to slices.ContainsFunc, S (type []ConcreteType) does not satisfy ~[]E`. Fixes golang/go#73564 Change-Id: Icfe6b170a5a89be4c9a2ed7c80eb7d7d53ef4d11 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669355 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Reviewed-by: Cherry Mui Auto-Submit: Alan Donovan --- .../analysis/modernize/slicescontains.go | 24 ++++++++++++++++--- .../src/slicescontains/slicescontains.go | 15 ++++++++++++ .../slicescontains/slicescontains.go.golden | 15 ++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/gopls/internal/analysis/modernize/slicescontains.go b/gopls/internal/analysis/modernize/slicescontains.go index 78a569eeca9..b5cd56022e1 100644 --- a/gopls/internal/analysis/modernize/slicescontains.go +++ b/gopls/internal/analysis/modernize/slicescontains.go @@ -129,9 +129,27 @@ func slicescontains(pass *analysis.Pass) { isSliceElem(cond.Args[0]) && typeutil.Callee(info, cond) != nil { // not a conversion - // skip variadic functions - if sig, ok := info.TypeOf(cond.Fun).(*types.Signature); ok && sig.Variadic() { - return + // Attempt to get signature + sig, isSignature := info.TypeOf(cond.Fun).(*types.Signature) + if isSignature { + // skip variadic functions + if 
sig.Variadic() { + return + } + + // Check for interface parameter with concrete argument, + // if the function has parameters. + if sig.Params().Len() > 0 { + paramType := sig.Params().At(0).Type() + elemType := info.TypeOf(cond.Args[0]) + + // If the function's first parameter is an interface + // and the argument passed is a concrete (non-interface) type, + // then we return and do not suggest this refactoring. + if types.IsInterface(paramType) && !types.IsInterface(elemType) { + return + } + } } funcName = "ContainsFunc" diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go index 03bcfc69904..326608725d4 100644 --- a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go +++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go @@ -169,3 +169,18 @@ func nopeVariadicContainsFunc(slice []int) bool { } return false } + +// Negative test case for implicit C->I conversion +type I interface{ F() } +type C int + +func (C) F() {} + +func nopeImplicitConversionContainsFunc(slice []C, f func(I) bool) bool { + for _, elem := range slice { + if f(elem) { + return true + } + } + return false +} diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden index 67e5b544960..9a16b749863 100644 --- a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden @@ -125,3 +125,18 @@ func nopeVariadicContainsFunc(slice []int) bool { } return false } + +// Negative test case for implicit C->I conversion +type I interface{ F() } +type C int + +func (C) F() {} + +func nopeImplicitConversionContainsFunc(slice []C, f func(I) bool) bool { + for _, elem := range slice { + 
if f(elem) { + return true + } + } + return false +} \ No newline at end of file From 6736a6d7b1cd791e00e52b2fab88bfa0e4afde15 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 11:26:58 -0400 Subject: [PATCH 021/196] jsonschema: resolve refs Implement the $ref keyword, except for remote references (those pointing outside the root schema). Change-Id: Ia5ca88464e62e9c613970d8d9fcd1512ab288934 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670775 Reviewed-by: Alan Donovan Auto-Submit: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/resolve.go | 144 ++- internal/mcp/jsonschema/resolve_test.go | 37 +- internal/mcp/jsonschema/schema.go | 5 +- .../jsonschema/testdata/draft2020-12/ref.json | 1052 +++++++++++++++++ internal/mcp/jsonschema/util.go | 2 +- internal/mcp/jsonschema/validate.go | 11 +- internal/mcp/jsonschema/validate_test.go | 16 +- 7 files changed, 1242 insertions(+), 25 deletions(-) create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/ref.json diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index ae3da3737c3..512ead0341d 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -12,6 +12,7 @@ import ( "fmt" "net/url" "regexp" + "strings" ) // A Resolved consists of a [Schema] along with associated information needed to @@ -25,22 +26,31 @@ type Resolved struct { resolvedURIs map[string]*Schema } +// A Loader reads and unmarshals the schema at uri, if any. +type Loader func(uri *url.URL) (*Schema, error) + // Resolve resolves all references within the schema and performs other tasks that // prepare the schema for validation. +// // baseURI can be empty, or an absolute URI (one that starts with a scheme). // It is resolved (in the URI sense; see [url.ResolveReference]) with root's $id property. // If the resulting URI is not absolute, then the schema cannot not contain relative URI references. 
-func (root *Schema) Resolve(baseURI string) (*Resolved, error) { - // There are three steps involved in preparing a schema to validate. - // 1. Check: validate the schema against a meta-schema, and perform other well-formedness - // checks. Precompute some values along the way. - // 2. Resolve URIs: determine the base URI of the root and all its subschemas, and +// +// loader loads schemas that are referred to by a $ref but not under root (a remote reference). +// If nil, remote references will return an error. +func (root *Schema) Resolve(baseURI string, loader Loader) (*Resolved, error) { + // There are four steps involved in preparing a schema to validate. + // 1. Load: read the schema from somewhere and unmarshal it. + // This schema (root) may have been loaded or created in memory, but other schemas that + // come into the picture in step 4 will be loaded by the given loader. + // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. + // Precompute some values along the way. + // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and // resolve (in the URI sense) all identifiers and anchors with their bases. This step results // in a map from URIs to schemas within root. - // 3. Resolve references: TODO. - if err := root.check(); err != nil { - return nil, err - } + // These three steps are idempotent. They may occur a several times on a schema, if + // it is loaded from several places. + // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. 
var base *url.URL if baseURI == "" { base = &url.URL{} // so we can call ResolveReference on it @@ -51,14 +61,104 @@ func (root *Schema) Resolve(baseURI string) (*Resolved, error) { return nil, fmt.Errorf("parsing base URI: %w", err) } } - m, err := resolveURIs(root, base) + + if loader == nil { + loader = func(uri *url.URL) (*Schema, error) { + return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") + } + } + r := &resolver{ + loader: loader, + loaded: map[string]*Resolved{}, + } + + return r.resolve(root, base) + // TODO: before we return, throw away anything we don't need for validation. +} + +// A resolver holds the state for resolution. +type resolver struct { + loader Loader + // A cache of loaded and partly resolved schemas. (They may not have had their + // refs resolved.) The cache ensures that the loader will never be called more + // than once with the same URI, and that reference cycles are handled properly. + loaded map[string]*Resolved +} + +func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { + if baseURI.Fragment != "" { + return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) + } + if err := s.check(); err != nil { + return nil, err + } + + m, err := resolveURIs(s, baseURI) if err != nil { return nil, err } - return &Resolved{ - root: root, - resolvedURIs: m, - }, nil + rs := &Resolved{root: s, resolvedURIs: m} + // Remember the schema by both the URI we loaded it from and its canonical name, + // which may differ if the schema has an $id. + // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. + r.loaded[baseURI.String()] = rs + r.loaded[s.baseURI.String()] = rs + + if err := r.resolveRefs(rs); err != nil { + return nil, err + } + return rs, nil +} + +// resolveRefs replaces all refs in the schemas with the schema they refer to. 
+// A reference that doesn't resolve within the schema may refer to some other schema +// that needs to be loaded. +func (r *resolver) resolveRefs(rs *Resolved) error { + for s := range rs.root.all() { + if s.Ref == "" { + continue + } + refURI, err := url.Parse(s.Ref) + if err != nil { + return err + } + // URI-resolve the ref against the current base URI to get a complete URI. + refURI = s.baseURI.ResolveReference(refURI) + // The non-fragment part of a ref URI refers to the base URI of some schema. + u := *refURI + u.Fragment = "" + fraglessRefURI := &u + // Look it up locally. + referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] + if referencedSchema == nil { + // The schema is remote. Maybe we've already loaded it. + // We assume that the non-fragment part of refURI refers to a top-level schema + // document. That is, we don't support the case exemplified by + // http://foo.com/bar.json/baz, where the document is in bar.json and + // the reference points to a subschema within it. + // TODO: support that case. + loadedResolved := r.loaded[fraglessRefURI.String()] + if loadedResolved == nil { + // Try to load the schema. + ls, err := r.loader(fraglessRefURI) + if err != nil { + return fmt.Errorf("loading %s: %w", fraglessRefURI, err) + } + loadedResolved, err = r.resolve(ls, fraglessRefURI) + if err != nil { + return err + } + } + referencedSchema = loadedResolved.root + assert(referencedSchema != nil, "nil referenced schema") + } + // The fragment selects the referenced schema, or a subschema of it. + s.resolvedRef, err = lookupFragment(referencedSchema, refURI.Fragment) + if err != nil { + return err + } + } + return nil } func (s *Schema) check() error { @@ -208,3 +308,19 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { } return resolvedURIs, nil } + +// lookupFragment returns the schema referenced by frag in s, or an error +// if there isn't one or something else went wrong. 
+func lookupFragment(s *Schema, frag string) (*Schema, error) { + // frag is either a JSON Pointer or the name of an anchor. + // A JSON Pointer is either the empty string or begins with a '/', + // whereas anchors are always non-empty strings that don't contain slashes. + if frag != "" && !strings.HasPrefix(frag, "/") { + if fs := s.anchors[frag]; fs != nil { + return fs, nil + } + return nil, fmt.Errorf("no anchor %q in %s", frag, s) + } + // frag is a JSON Pointer. Follow it. + return dereferenceJSONPointer(s, frag) +} diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index f8fb2b5dfb1..2b469d4dd9a 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -5,6 +5,7 @@ package jsonschema import ( + "errors" "maps" "net/url" "regexp" @@ -27,7 +28,7 @@ func TestCheckLocal(t *testing.T) { "regexp", }, } { - _, err := tt.s.Resolve("") + _, err := tt.s.Resolve("", nil) if err == nil { t.Errorf("%s: unexpectedly passed", tt.s.json()) continue @@ -104,3 +105,37 @@ func TestResolveURIs(t *testing.T) { }) } } + +func TestRefCycle(t *testing.T) { + // Verify that cycles of refs are OK. + // The test suite doesn't check this, suprisingly. 
+ schemas := map[string]*Schema{ + "root": {Ref: "a"}, + "a": {Ref: "b"}, + "b": {Ref: "a"}, + } + + loader := func(uri *url.URL) (*Schema, error) { + s, ok := schemas[uri.Path[1:]] + if !ok { + return nil, errors.New("not found") + } + return s, nil + } + + rs, err := schemas["root"].Resolve("", loader) + if err != nil { + t.Fatal(err) + } + + check := func(s *Schema, key string) { + t.Helper() + if s.resolvedRef != schemas[key] { + t.Errorf("%s resolvedRef != schemas[%q]", s.json(), key) + } + } + + check(rs.root, "a") + check(schemas["a"], "b") + check(schemas["b"], "a") +} diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index ade91e1ef31..c032661baf1 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -115,8 +115,11 @@ type Schema struct { // The parent base URI at top level is where the schema was loaded from, or // if not loaded, then it should be provided to Schema.Resolve. baseURI *url.URL + // The schema to which Ref refers. 
+ resolvedRef *Schema // map from anchors to subschemas - anchors map[string]*Schema + anchors map[string]*Schema + // compiled regexps pattern *regexp.Regexp patternProperties map[*regexp.Regexp]*Schema } diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/ref.json b/internal/mcp/jsonschema/testdata/draft2020-12/ref.json new file mode 100644 index 00000000000..0ac02fb9139 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/ref.json @@ -0,0 +1,1052 @@ +[ + { + "description": "root pointer ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "foo": {"$ref": "#"} + }, + "additionalProperties": false + }, + "tests": [ + { + "description": "match", + "data": {"foo": false}, + "valid": true + }, + { + "description": "recursive match", + "data": {"foo": {"foo": false}}, + "valid": true + }, + { + "description": "mismatch", + "data": {"bar": false}, + "valid": false + }, + { + "description": "recursive mismatch", + "data": {"foo": {"bar": false}}, + "valid": false + } + ] + }, + { + "description": "relative pointer ref to object", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "foo": {"type": "integer"}, + "bar": {"$ref": "#/properties/foo"} + } + }, + "tests": [ + { + "description": "match", + "data": {"bar": 3}, + "valid": true + }, + { + "description": "mismatch", + "data": {"bar": true}, + "valid": false + } + ] + }, + { + "description": "relative pointer ref to array", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "prefixItems": [ + {"type": "integer"}, + {"$ref": "#/prefixItems/0"} + ] + }, + "tests": [ + { + "description": "match array", + "data": [1, 2], + "valid": true + }, + { + "description": "mismatch array", + "data": [1, "foo"], + "valid": false + } + ] + }, + { + "description": "escaped pointer ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "tilde~field": {"type": 
"integer"}, + "slash/field": {"type": "integer"}, + "percent%field": {"type": "integer"} + }, + "properties": { + "tilde": {"$ref": "#/$defs/tilde~0field"}, + "slash": {"$ref": "#/$defs/slash~1field"}, + "percent": {"$ref": "#/$defs/percent%25field"} + } + }, + "tests": [ + { + "description": "slash invalid", + "data": {"slash": "aoeu"}, + "valid": false + }, + { + "description": "tilde invalid", + "data": {"tilde": "aoeu"}, + "valid": false + }, + { + "description": "percent invalid", + "data": {"percent": "aoeu"}, + "valid": false + }, + { + "description": "slash valid", + "data": {"slash": 123}, + "valid": true + }, + { + "description": "tilde valid", + "data": {"tilde": 123}, + "valid": true + }, + { + "description": "percent valid", + "data": {"percent": 123}, + "valid": true + } + ] + }, + { + "description": "nested refs", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "a": {"type": "integer"}, + "b": {"$ref": "#/$defs/a"}, + "c": {"$ref": "#/$defs/b"} + }, + "$ref": "#/$defs/c" + }, + "tests": [ + { + "description": "nested ref valid", + "data": 5, + "valid": true + }, + { + "description": "nested ref invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "ref applies alongside sibling keywords", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "reffed": { + "type": "array" + } + }, + "properties": { + "foo": { + "$ref": "#/$defs/reffed", + "maxItems": 2 + } + } + }, + "tests": [ + { + "description": "ref valid, maxItems valid", + "data": { "foo": [] }, + "valid": true + }, + { + "description": "ref valid, maxItems invalid", + "data": { "foo": [1, 2, 3] }, + "valid": false + }, + { + "description": "ref invalid", + "data": { "foo": "string" }, + "valid": false + } + ] + }, + { + "description": "remote ref, containing refs itself", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": 
"https://json-schema.org/draft/2020-12/schema" + }, + "tests": [ + { + "description": "remote ref valid", + "data": {"minLength": 1}, + "valid": true + }, + { + "description": "remote ref invalid", + "data": {"minLength": -1}, + "valid": false + } + ] + }, + { + "description": "property named $ref that is not a reference", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "$ref": {"type": "string"} + } + }, + "tests": [ + { + "description": "property named $ref valid", + "data": {"$ref": "a"}, + "valid": true + }, + { + "description": "property named $ref invalid", + "data": {"$ref": 2}, + "valid": false + } + ] + }, + { + "description": "property named $ref, containing an actual $ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "$ref": {"$ref": "#/$defs/is-string"} + }, + "$defs": { + "is-string": { + "type": "string" + } + } + }, + "tests": [ + { + "description": "property named $ref valid", + "data": {"$ref": "a"}, + "valid": true + }, + { + "description": "property named $ref invalid", + "data": {"$ref": 2}, + "valid": false + } + ] + }, + { + "description": "$ref to boolean schema true", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/bool", + "$defs": { + "bool": true + } + }, + "tests": [ + { + "description": "any value is valid", + "data": "foo", + "valid": true + } + ] + }, + { + "description": "$ref to boolean schema false", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/bool", + "$defs": { + "bool": false + } + }, + "tests": [ + { + "description": "any value is invalid", + "data": "foo", + "valid": false + } + ] + }, + { + "description": "Recursive references between schemas", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/tree", + "description": "tree of nodes", + "type": "object", + 
"properties": { + "meta": {"type": "string"}, + "nodes": { + "type": "array", + "items": {"$ref": "node"} + } + }, + "required": ["meta", "nodes"], + "$defs": { + "node": { + "$id": "http://localhost:1234/draft2020-12/node", + "description": "node", + "type": "object", + "properties": { + "value": {"type": "number"}, + "subtree": {"$ref": "tree"} + }, + "required": ["value"] + } + } + }, + "tests": [ + { + "description": "valid tree", + "data": { + "meta": "root", + "nodes": [ + { + "value": 1, + "subtree": { + "meta": "child", + "nodes": [ + {"value": 1.1}, + {"value": 1.2} + ] + } + }, + { + "value": 2, + "subtree": { + "meta": "child", + "nodes": [ + {"value": 2.1}, + {"value": 2.2} + ] + } + } + ] + }, + "valid": true + }, + { + "description": "invalid tree", + "data": { + "meta": "root", + "nodes": [ + { + "value": 1, + "subtree": { + "meta": "child", + "nodes": [ + {"value": "string is invalid"}, + {"value": 1.2} + ] + } + }, + { + "value": 2, + "subtree": { + "meta": "child", + "nodes": [ + {"value": 2.1}, + {"value": 2.2} + ] + } + } + ] + }, + "valid": false + } + ] + }, + { + "description": "refs with quote", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "foo\"bar": {"$ref": "#/$defs/foo%22bar"} + }, + "$defs": { + "foo\"bar": {"type": "number"} + } + }, + "tests": [ + { + "description": "object with numbers is valid", + "data": { + "foo\"bar": 1 + }, + "valid": true + }, + { + "description": "object with strings is invalid", + "data": { + "foo\"bar": "1" + }, + "valid": false + } + ] + }, + { + "description": "ref creates new scope when adjacent to keywords", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "A": { + "unevaluatedProperties": false + } + }, + "properties": { + "prop1": { + "type": "string" + } + }, + "$ref": "#/$defs/A" + }, + "tests": [ + { + "description": "referenced subschema doesn't see annotations from properties", + "data": { + "prop1": "match" 
+ }, + "valid": false + } + ] + }, + { + "description": "naive replacement of $ref with its destination is not correct", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "a_string": { "type": "string" } + }, + "enum": [ + { "$ref": "#/$defs/a_string" } + ] + }, + "tests": [ + { + "description": "do not evaluate the $ref inside the enum, matching any string", + "data": "this is a string", + "valid": false + }, + { + "description": "do not evaluate the $ref inside the enum, definition exact match", + "data": { "type": "string" }, + "valid": false + }, + { + "description": "match the enum exactly", + "data": { "$ref": "#/$defs/a_string" }, + "valid": true + } + ] + }, + { + "description": "refs with relative uris and defs", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://example.com/schema-relative-uri-defs1.json", + "properties": { + "foo": { + "$id": "schema-relative-uri-defs2.json", + "$defs": { + "inner": { + "properties": { + "bar": { "type": "string" } + } + } + }, + "$ref": "#/$defs/inner" + } + }, + "$ref": "schema-relative-uri-defs2.json" + }, + "tests": [ + { + "description": "invalid on inner field", + "data": { + "foo": { + "bar": 1 + }, + "bar": "a" + }, + "valid": false + }, + { + "description": "invalid on outer field", + "data": { + "foo": { + "bar": "a" + }, + "bar": 1 + }, + "valid": false + }, + { + "description": "valid on both fields", + "data": { + "foo": { + "bar": "a" + }, + "bar": "a" + }, + "valid": true + } + ] + }, + { + "description": "relative refs with absolute uris and defs", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://example.com/schema-refs-absolute-uris-defs1.json", + "properties": { + "foo": { + "$id": "http://example.com/schema-refs-absolute-uris-defs2.json", + "$defs": { + "inner": { + "properties": { + "bar": { "type": "string" } + } + } + }, + "$ref": "#/$defs/inner" + } + }, + "$ref": 
"schema-refs-absolute-uris-defs2.json" + }, + "tests": [ + { + "description": "invalid on inner field", + "data": { + "foo": { + "bar": 1 + }, + "bar": "a" + }, + "valid": false + }, + { + "description": "invalid on outer field", + "data": { + "foo": { + "bar": "a" + }, + "bar": 1 + }, + "valid": false + }, + { + "description": "valid on both fields", + "data": { + "foo": { + "bar": "a" + }, + "bar": "a" + }, + "valid": true + } + ] + }, + { + "description": "$id must be resolved against nearest parent, not just immediate parent", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://example.com/a.json", + "$defs": { + "x": { + "$id": "http://example.com/b/c.json", + "not": { + "$defs": { + "y": { + "$id": "d.json", + "type": "number" + } + } + } + } + }, + "allOf": [ + { + "$ref": "http://example.com/b/d.json" + } + ] + }, + "tests": [ + { + "description": "number is valid", + "data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "order of evaluation: $id and $ref", + "schema": { + "$comment": "$id must be evaluated before $ref to get the proper $ref destination", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://example.com/draft2020-12/ref-and-id1/base.json", + "$ref": "int.json", + "$defs": { + "bigint": { + "$comment": "canonical uri: https://example.com/ref-and-id1/int.json", + "$id": "int.json", + "maximum": 10 + }, + "smallint": { + "$comment": "canonical uri: https://example.com/ref-and-id1-int.json", + "$id": "/draft2020-12/ref-and-id1-int.json", + "maximum": 2 + } + } + }, + "tests": [ + { + "description": "data is valid against first definition", + "data": 5, + "valid": true + }, + { + "description": "data is invalid against first definition", + "data": 50, + "valid": false + } + ] + }, + { + "description": "order of evaluation: $id and $anchor and $ref", + "schema": { + "$comment": "$id must be 
evaluated before $ref to get the proper $ref destination", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://example.com/draft2020-12/ref-and-id2/base.json", + "$ref": "#bigint", + "$defs": { + "bigint": { + "$comment": "canonical uri: /ref-and-id2/base.json#/$defs/bigint; another valid uri for this location: /ref-and-id2/base.json#bigint", + "$anchor": "bigint", + "maximum": 10 + }, + "smallint": { + "$comment": "canonical uri: https://example.com/ref-and-id2#/$defs/smallint; another valid uri for this location: https://example.com/ref-and-id2/#bigint", + "$id": "https://example.com/draft2020-12/ref-and-id2/", + "$anchor": "bigint", + "maximum": 2 + } + } + }, + "tests": [ + { + "description": "data is valid against first definition", + "data": 5, + "valid": true + }, + { + "description": "data is invalid against first definition", + "data": 50, + "valid": false + } + ] + }, + { + "description": "simple URN base URI with $ref via the URN", + "schema": { + "$comment": "URIs do not have to have HTTP(s) schemes", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed", + "minimum": 30, + "properties": { + "foo": {"$ref": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed"} + } + }, + "tests": [ + { + "description": "valid under the URN IDed schema", + "data": {"foo": 37}, + "valid": true + }, + { + "description": "invalid under the URN IDed schema", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "simple URN base URI with JSON pointer", + "schema": { + "$comment": "URIs do not have to have HTTP(s) schemes", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:uuid:deadbeef-1234-00ff-ff00-4321feebdaed", + "properties": { + "foo": {"$ref": "#/$defs/bar"} + }, + "$defs": { + "bar": {"type": "string"} + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a 
non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN base URI with NSS", + "schema": { + "$comment": "RFC 8141 §2.2", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:example:1/406/47452/2", + "properties": { + "foo": {"$ref": "#/$defs/bar"} + }, + "$defs": { + "bar": {"type": "string"} + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN base URI with r-component", + "schema": { + "$comment": "RFC 8141 §2.3.1", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:example:foo-bar-baz-qux?+CCResolve:cc=uk", + "properties": { + "foo": {"$ref": "#/$defs/bar"} + }, + "$defs": { + "bar": {"type": "string"} + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN base URI with q-component", + "schema": { + "$comment": "RFC 8141 §2.3.2", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:example:weather?=op=map&lat=39.56&lon=-104.85&datetime=1969-07-21T02:56:15Z", + "properties": { + "foo": {"$ref": "#/$defs/bar"} + }, + "$defs": { + "bar": {"type": "string"} + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN base URI with URN and JSON pointer ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed", + "properties": { + "foo": {"$ref": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed#/$defs/bar"} + }, + "$defs": { + "bar": {"type": 
"string"} + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN base URI with URN and anchor ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed", + "properties": { + "foo": {"$ref": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed#something"} + }, + "$defs": { + "bar": { + "$anchor": "something", + "type": "string" + } + } + }, + "tests": [ + { + "description": "a string is valid", + "data": {"foo": "bar"}, + "valid": true + }, + { + "description": "a non-string is invalid", + "data": {"foo": 12}, + "valid": false + } + ] + }, + { + "description": "URN ref with nested pointer ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", + "$defs": { + "foo": { + "$id": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", + "$defs": {"bar": {"type": "string"}}, + "$ref": "#/$defs/bar" + } + } + }, + "tests": [ + { + "description": "a string is valid", + "data": "bar", + "valid": true + }, + { + "description": "a non-string is invalid", + "data": 12, + "valid": false + } + ] + }, + { + "description": "ref to if", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://example.com/ref/if", + "if": { + "$id": "http://example.com/ref/if", + "type": "integer" + } + }, + "tests": [ + { + "description": "a non-integer is invalid due to the $ref", + "data": "foo", + "valid": false + }, + { + "description": "an integer is valid", + "data": 12, + "valid": true + } + ] + }, + { + "description": "ref to then", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://example.com/ref/then", + "then": { + "$id": "http://example.com/ref/then", + "type": "integer" + } + }, + 
"tests": [ + { + "description": "a non-integer is invalid due to the $ref", + "data": "foo", + "valid": false + }, + { + "description": "an integer is valid", + "data": 12, + "valid": true + } + ] + }, + { + "description": "ref to else", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://example.com/ref/else", + "else": { + "$id": "http://example.com/ref/else", + "type": "integer" + } + }, + "tests": [ + { + "description": "a non-integer is invalid due to the $ref", + "data": "foo", + "valid": false + }, + { + "description": "an integer is valid", + "data": 12, + "valid": true + } + ] + }, + { + "description": "ref with absolute-path-reference", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://example.com/ref/absref.json", + "$defs": { + "a": { + "$id": "http://example.com/ref/absref/foobar.json", + "type": "number" + }, + "b": { + "$id": "http://example.com/absref/foobar.json", + "type": "string" + } + }, + "$ref": "/absref/foobar.json" + }, + "tests": [ + { + "description": "a string is valid", + "data": "foo", + "valid": true + }, + { + "description": "an integer is invalid", + "data": 12, + "valid": false + } + ] + }, + { + "description": "$id with file URI still resolves pointers - *nix", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "file:///folder/file.json", + "$defs": { + "foo": { + "type": "number" + } + }, + "$ref": "#/$defs/foo" + }, + "tests": [ + { + "description": "number is valid", + "data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "$id with file URI still resolves pointers - windows", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "file:///c:/folder/file.json", + "$defs": { + "foo": { + "type": "number" + } + }, + "$ref": "#/$defs/foo" + }, + "tests": [ + { + "description": "number is valid", + 
"data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "empty tokens in $ref json-pointer", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "": { + "$defs": { + "": { "type": "number" } + } + } + }, + "allOf": [ + { + "$ref": "#/$defs//$defs/" + } + ] + }, + "tests": [ + { + "description": "number is valid", + "data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + } +] diff --git a/internal/mcp/jsonschema/util.go b/internal/mcp/jsonschema/util.go index 266a324f338..7e07345f8cc 100644 --- a/internal/mcp/jsonschema/util.go +++ b/internal/mcp/jsonschema/util.go @@ -279,6 +279,6 @@ func jsonType(v reflect.Value) (string, bool) { func assert(cond bool, msg string) { if !cond { - panic(msg) + panic("assertion failed: " + msg) } } diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 28ab0fb1e4b..b529e232ad5 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -153,6 +153,15 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } } + var anns annotations // all the annotations for this call and child calls + + // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 + if schema.Ref != "" { + if err := st.validate(instance, schema.resolvedRef, &anns, path); err != nil { + return err + } + } + // logic // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 // These must happen before arrays and objects because if they evaluate an item or property, @@ -162,8 +171,6 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // If any of these fail, then validation fails, even if there is an unevaluatedXXX // keyword in the schema. The spec is unclear about this, but that is the intention. 
- var anns annotations // all the annotations for this call and child calls - valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns, path) == nil } if schema.AllOf != nil { diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 79cd19e51e3..f8fe929eb0e 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "testing" ) @@ -51,15 +52,18 @@ func TestValidate(t *testing.T) { } for _, g := range groups { t.Run(g.Description, func(t *testing.T) { - rs, err := g.Schema.Resolve("") - if err != nil { - t.Fatal(err) + if strings.Contains(g.Description, "remote ref") { + t.Skip("remote refs not yet supported") } for s := range g.Schema.all() { - if s.Defs != nil || s.Ref != "" { + if s.DynamicAnchor != "" || s.DynamicRef != "" { t.Skip("schema or subschema has unimplemented keywords") } } + rs, err := g.Schema.Resolve("", nil) + if err != nil { + t.Fatal(err) + } for _, test := range g.Tests { t.Run(test.Description, func(t *testing.T) { err = rs.Validate(test.Data) @@ -71,7 +75,7 @@ func TestValidate(t *testing.T) { } if t.Failed() { t.Errorf("schema: %s", g.Schema.json()) - t.Fatalf("instance: %v", test.Data) + t.Fatalf("instance: %v (%[1]T)", test.Data) } }) } @@ -102,7 +106,7 @@ func TestStructInstance(t *testing.T) { {DependentSchemas: map[string]*Schema{"b": falseSchema()}}, {UnevaluatedProperties: falseSchema()}, } { - res, err := schema.Resolve("") + res, err := schema.Resolve("", nil) if err != nil { t.Fatal(err) } From 3d893350aea4008bc6b87ef367398e977c977480 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 8 May 2025 15:12:40 +0000 Subject: [PATCH 022/196] internal/mcp: a new extensibility point for HTTP transports Expose the SSEServerTransport type, so that clients can customize the handling of sessions and endpoints by writing their own HTTP handler. 
This is similar to the behavior of the TypeScript SDK. Also, finish documenting transports in the design doc. Change-Id: Ief5aa8424ba2946a9615e366969a4750f40224a8 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671015 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Auto-Submit: Robert Findley --- internal/mcp/design/design.md | 83 +++++++++++++-- internal/mcp/sse.go | 189 ++++++++++++++++++++-------------- 2 files changed, 188 insertions(+), 84 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 6b6a29ebb59..6b76e5d0b01 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -128,7 +128,7 @@ type CommandTransport struct { /* unexported fields */ } func NewCommandTransport(cmd *exec.Command) *CommandTransport // Connect starts the command, and connects to it over stdin/stdout. -func (t *CommandTransport) Connect(ctx context.Context) (Stream, error) { +func (*CommandTransport) Connect(ctx context.Context) (Stream, error) { ``` The `StdIOTransport` is the server side of the stdio transport, and connects by @@ -139,7 +139,7 @@ binding to `os.Stdin` and `os.Stdout`. // JSON over stdin/stdout. type StdIOTransport struct { /* unexported fields */ } -func NewStdIOTransport() *StdIOTransport { +func NewStdIOTransport() *StdIOTransport func (t *StdIOTransport) Connect(context.Context) (Stream, error) ``` @@ -174,13 +174,80 @@ Notably absent are options to hook into the request handling for the purposes of authentication or context injection. These concerns are better handled using standard HTTP middleware patterns. - +By default, the SSE handler creates messages endpoints with the +`?sessionId=...` query parameter. Users that want more control over the +management of sessions and session endpoints may write their own handler, and +create `SSEServerTransport` instances themselves, for incoming GET requests.
+ +```go +// A SSEServerTransport is a logical SSE session created through a hanging GET +// request. +// +// When connected, it it returns the following [Stream] implementation: +// - Writes are SSE 'message' events to the GET response. +// - Reads are received from POSTs to the session endpoint, via +// [SSEServerTransport.ServeHTTP]. +// - Close terminates the hanging GET. +type SSEServerTransport struct { /* ... */ } + +// NewSSEServerTransport creates a new SSE transport for the given messages +// endpoint, and hanging GET response. +// +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// +// The transport is itself an [http.Handler]. It is the caller's responsibility +// to ensure that the resulting transport serves HTTP requests on the given +// session endpoint. +func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport + +// ServeHTTP handles POST requests to the transport endpoint. +func (*SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) + +// Connect sends the 'endpoint' event to the client. +// See [SSEServerTransport] for more details on the [Stream] implementation. +func (*SSEServerTransport) Connect(context.Context) (Stream, error) +``` - +The SSE client transport is simpler, and hopefully self-explanatory. + +```go +type SSEClientTransport struct { /* ... */ } + +// NewSSEClientTransport returns a new client transport that connects to the +// SSE server at the provided URL. +// +// NewSSEClientTransport panics if the given URL is invalid. +func NewSSEClientTransport(url string) *SSEClientTransport { + +// Connect connects through the client endpoint. +func (*SSEClientTransport) Connect(ctx context.Context) (Stream, error) +``` + +The Streamable HTTP transports are similar to the SSE transport, albeit with a +more complicated implementation. 
For brevity, we summarize only the differences +from the equivalent SSE types: + +```go +// The StreamableHandler interface is symmetrical to the SSEHandler. +type StreamableHandler struct { /* unexported fields */ } +func NewStreamableHandler(getServer func(request *http.Request) *Server) *StreamableHandler +func (*StreamableHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) +func (*StreamableHandler) Close() error + +// Unlike the SSE transport, the streamable transport constructor accepts a +// session ID, not an endpoint, along with the http response for the request +// that created the session. It is the caller's responsibility to delegate +// requests to this session. +type StreamableServerTransport struct { /* ... */ } +func NewStreamableServerTransport(sessionID string, w http.ResponseWriter) *StreamableServerTransport +func (*StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) +func (*StreamableServerTransport) Connect(context.Context) (Stream, error) + +// The streamable client handles reconnection transparently to the user. +type StreamableClientTransport struct { /* ... */ } +func NewStreamableClientTransport(url string) *StreamableClientTransport { +func (*StreamableClientTransport) Connect(context.Context) (Stream, error) +``` ### Protocol types diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index 511d4bca70f..c2acdd5c201 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -67,27 +67,38 @@ type SSEHandler struct { onConnection func(*ServerConnection) // for testing; must not block mu sync.Mutex - sessions map[string]*sseSession + sessions map[string]*SSEServerTransport } -// NewSSEHandler returns a new [SSEHandler] that is ready to serve HTTP. +// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP +// sessions created via incoming HTTP requests. // -// The getServer function is used to bind created servers for new sessions. 
It -// is OK for getServer to return the same server multiple times. +// Sessions are created when the client issues a GET request to the server, +// which must accept text/event-stream responses (server-sent events). +// For each such request, a new [SSEServerTransport] is created with a distinct +// messages endpoint, and connected to the server returned by getServer. It is +// up to the user whether getServer returns a distinct [Server] for each new +// request, or reuses an existing server. +// +// The SSEHandler also handles requests to the message endpoints, by +// delegating them to the relevant server transport. func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { return &SSEHandler{ getServer: getServer, - sessions: make(map[string]*sseSession), + sessions: make(map[string]*SSEServerTransport), } } -// A sseSession is a logical jsonrpc2 stream implementing the server side of -// MCP SSE transport, initiated through the hanging GET -// - Writes are SSE 'message' events to the GET response body. -// - Reads are received from POSTs to the session endpoing, mediated through a -// buffered channel. +// A SSEServerTransport is a logical SSE session created through a hanging GET +// request. +// +// When connected, it returns the following [Stream] implementation: +// - Writes are SSE 'message' events to the GET response. +// - Reads are received from POSTs to the session endpoint, via +// [SSEServerTransport.ServeHTTP]. // - Close terminates the hanging GET. -type sseSession struct { +type SSEServerTransport struct { + endpoint string incoming chan jsonrpc2.Message // queue of incoming messages; never closed // We must guard both pushes to the incoming queue and writes to the response @@ -100,9 +111,63 @@ type sseSession struct { done chan struct{} // closed when the stream is closed } -// Connect returns the receiver, as an sseSession is a logical stream. 
-func (s *sseSession) Connect(context.Context) (Stream, error) { - return s, nil +// NewSSEServerTransport creates a new SSE transport for the given messages +// endpoint, and hanging GET response. +// +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// +// The transport is itself an [http.Handler]. It is the caller's responsibility +// to ensure that the resulting transport serves HTTP requests on the given +// session endpoint. +// +// Most callers should instead use an [SSEHandler], which transparently handles +// the delegation to SSEServerTransports. +func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport { + return &SSEServerTransport{ + endpoint: endpoint, + w: w, + incoming: make(chan jsonrpc2.Message, 100), + done: make(chan struct{}), + } +} + +// ServeHTTP handles POST requests to the transport endpoint. +func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Read and parse the message. + data, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + // Optionally, we could just push the data onto a channel, and let the + // message fail to parse when it is read. This failure seems a bit more + // useful + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + select { + case t.incoming <- msg: + w.WriteHeader(http.StatusAccepted) + case <-t.done: + http.Error(w, "session closed", http.StatusBadRequest) + } +} + +// Connect sends the 'endpoint' event to the client. +// See [SSEServerTransport] for more details on the [Stream] implementation. 
+func (t *SSEServerTransport) Connect(context.Context) (Stream, error) { + t.mu.Lock() + _, err := writeEvent(t.w, event{ + name: "endpoint", + data: []byte(t.endpoint), + }) + t.mu.Unlock() + if err != nil { + return nil, err + } + return sseServerStream{t}, nil } func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -125,26 +190,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - // Read and parse the message. - data, err := io.ReadAll(req.Body) - if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) - return - } - // Optionally, we could just push the data onto a channel, and let the - // message fail to parse when it is read. This failure seems a bit more - // useful - msg, err := jsonrpc2.DecodeMessage(data) - if err != nil { - http.Error(w, "failed to parse body", http.StatusBadRequest) - return - } - select { - case session.incoming <- msg: - w.WriteHeader(http.StatusAccepted) - case <-session.done: - http.Error(w, "session closed", http.StatusBadRequest) - } + session.ServeHTTP(w, req) return } @@ -163,28 +209,17 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.Header().Set("Connection", "keep-alive") sessionID = randText() - session := &sseSession{ - w: w, - incoming: make(chan jsonrpc2.Message, 100), - done: make(chan struct{}), - } - defer session.Close() - - // TODO(hxjiang): getServer returns nil will panic. 
- server := h.getServer(req) - cc, err := server.Connect(req.Context(), session, nil) + endpoint, err := req.URL.Parse("?sessionid=" + sessionID) if err != nil { - http.Error(w, "connection failed", http.StatusInternalServerError) + http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) return } - if h.onConnection != nil { - h.onConnection(cc) - } - defer cc.Close() // close the transport when the GET exits + + transport := NewSSEServerTransport(endpoint.RequestURI(), w) // The session is terminated when the request exits. h.mu.Lock() - h.sessions[sessionID] = session + h.sessions[sessionID] = transport h.mu.Unlock() defer func() { h.mu.Lock() @@ -192,42 +227,44 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { h.mu.Unlock() }() - endpoint, err := req.URL.Parse("?sessionid=" + sessionID) + // TODO(hxjiang): getServer returns nil will panic. + server := h.getServer(req) + cc, err := server.Connect(req.Context(), transport, nil) if err != nil { - http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) + http.Error(w, "connection failed", http.StatusInternalServerError) return } - - session.mu.Lock() - _, err = writeEvent(w, event{ - name: "endpoint", - data: []byte(endpoint.RequestURI()), - }) - session.mu.Unlock() - if err != nil { - return // too late to write the status header + if h.onConnection != nil { + h.onConnection(cc) } + defer cc.Close() // close the transport when the GET exits select { case <-req.Context().Done(): - case <-session.done: + case <-transport.done: } } +// sseServerStream implements the Stream interface for a single [SSEServerTransport]. +// It hides the Stream interface from the SSEServerTransport API. +type sseServerStream struct { + t *SSEServerTransport +} + // Read implements jsonrpc2.Reader. 
-func (s *sseSession) Read(ctx context.Context) (jsonrpc2.Message, int64, error) { +func (s sseServerStream) Read(ctx context.Context) (jsonrpc2.Message, int64, error) { select { case <-ctx.Done(): return nil, 0, ctx.Err() - case msg := <-s.incoming: + case msg := <-s.t.incoming: return msg, 0, nil - case <-s.done: + case <-s.t.done: return nil, 0, io.EOF } } // Write implements jsonrpc2.Writer. -func (s *sseSession) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) { +func (s sseServerStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) { if ctx.Err() != nil { return 0, ctx.Err() } @@ -237,17 +274,17 @@ func (s *sseSession) Write(ctx context.Context, msg jsonrpc2.Message) (int64, er return 0, err } - s.mu.Lock() - defer s.mu.Unlock() + s.t.mu.Lock() + defer s.t.mu.Unlock() // Note that it is invalid to write to a ResponseWriter after ServeHTTP has // exited, and so we must lock around this write and check isDone, which is // set before the hanging GET exits. - if s.closed { + if s.t.closed { return 0, io.EOF } - n, err := writeEvent(s.w, event{name: "message", data: data}) + n, err := writeEvent(s.t.w, event{name: "message", data: data}) return int64(n), err } @@ -256,12 +293,12 @@ func (s *sseSession) Write(ctx context.Context, msg jsonrpc2.Message) (int64, er // It must be safe to call Close more than once, as the close may // asynchronously be initiated by either the server closing its connection, or // by the hanging GET exiting. 
-func (s *sseSession) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if !s.closed { - s.closed = true - close(s.done) +func (s sseServerStream) Close() error { + s.t.mu.Lock() + defer s.t.mu.Unlock() + if !s.t.closed { + s.t.closed = true + close(s.t.done) } return nil } From 2587caa966c93bd847b7950d4c7cc5fdae4b08c8 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 8 May 2025 15:41:20 +0000 Subject: [PATCH 023/196] internal/mcp/design: discuss generated protocol types Discuss generated protocol types, and expand on the section on package layout to make it clearer how they fit in to the SDK. Change-Id: I3beaa86c2396e0cee41dcc8eac06d93c54cb96df Reviewed-on: https://go-review.googlesource.com/c/tools/+/671016 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/design/design.md | 71 +++++++++++++++++++++++++++++++- internal/mcp/protocol/content.go | 10 ++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 6b76e5d0b01..3c8e48847d3 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -56,6 +56,23 @@ SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) belongs in a separate package. +Therefore, this is the package layout. `module.path` is a placeholder for the +final module path of the mcp module + +- `module.path/mcp`: the bulk of the user facing API +- `module.path/mcp/protocol`: generated types for the MCP spec. +- `module.path/jsonschema`: a jsonschema implementation, with validation +- `module.path/internal/jsonrpc2`: a fork of x/tools/internal/jsonrpc2_v2 + +For now, this layout assumes we want to separate the 'protocol' types from the +'mcp' package, since they won't be needed by most users. It is unclear whether +this is worthwhile. + +The JSON-RPC implementation is hidden, to avoid tight coupling. 
As described in +the next section, the only aspects of JSON-RPC that need to be exposed in the +SDK are the message types, for the purposes of defining custom transports. We +can expose these types from the `mcp` package via aliases or wrappers. + ### jsonrpc2 and Transports The MCP is defined in terms of client-server communication over bidirectional @@ -251,7 +268,59 @@ func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ### Protocol types - +As described in the section on package layout above, the `protocol` package +will contain definitions of types referenced by the MCP spec that are needed +for the SDK. JSON-RPC message types are elided, since they are handled by the +`jsonrpc2` package and should not be observed by the user. The user interacts +only with the params/result types relevant to MCP operations. + +For user-provided data, use `json.RawMessage`, so that +marshalling/unmarshalling can be delegated to the business logic of the client +or server. + +For union types, which can't be represented in Go (specifically `Content` and +`Resource`), we prefer distinguished unions: struct types with fields +corresponding to the union of all properties for union elements. + +These types will be auto-generated from the [JSON schema of the MCP +spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). +For brevity, only a few examples are shown here: + +```go +type CallToolParams struct { + Arguments map[string]json.RawMessage `json:"arguments,omitempty"` + Name string `json:"name"` +} + +type CallToolResult struct { + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// Content is the wire format for content. +// +// The Type field distinguishes the type of the content. +// At most one of Text, MIMEType, Data, and Resource is non-zero. 
+type Content struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data string `json:"data,omitempty"` + Resource *Resource `json:"resource,omitempty"` +} + +// Resource is the wire format for embedded resources. +// +// The URI field describes the resource location. At most one of Text and Blob +// is non-zero. +type Resource struct { + URI string `json:"uri,"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text"` + Blob *string `json:"blob"` +} +``` ### Clients and Servers diff --git a/internal/mcp/protocol/content.go b/internal/mcp/protocol/content.go index 5374b62488b..d5fc16f894d 100644 --- a/internal/mcp/protocol/content.go +++ b/internal/mcp/protocol/content.go @@ -9,7 +9,10 @@ import ( "fmt" ) -// Content is the wire format for content, including all fields. +// Content is the wire format for content. +// +// The Type field distinguishes the type of the content. +// At most one of Text, MIMEType, Data, and Resource is non-zero. type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` @@ -18,7 +21,10 @@ type Content struct { Resource *Resource `json:"resource,omitempty"` } -// Resource is the wire format for embedded resources, including all fields. +// Resource is the wire format for embedded resources. +// +// The URI field describes the resource location. At most one of Text and Blob +// is non-zero. type Resource struct { URI string `json:"uri,"` MIMEType string `json:"mimeType,omitempty"` From 1ead56f2bffa3c26f71b09e673381f7522a0198b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 8 May 2025 08:36:10 -0400 Subject: [PATCH 024/196] internal/mcp/design: describe mcp-go delta Describe the differences with mcp-go, and justify our choices. 
Change-Id: Ib9b7ed87b8038f0b7cbc0eeffc5d7a7075eeae30 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671017 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 145 +++++++++++++++++++++++++++++----- 1 file changed, 125 insertions(+), 20 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 3c8e48847d3..f5d63342391 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -428,31 +428,136 @@ follows: -## Differences with mcp-go +## Differences with mark3labs/mcp-go + +The most popular MCP module for Go is [mark3labs/mcp-go](https://pkg.go.dev/github.com/ +mark3labs/mcp-go). +As of this writing, it is imported by over 400 packages that span over 200 modules. + +We admire mcp-go, and seriously considered simply adopting it as a starting +point for this SDK. However, as we looked at doing so, we realized that a +significant amount of its API would probably need to change. In some cases, +mcp-go has older APIs that predated newer variations--an obvious opportunity +for cleanup. In others, it took a batteries-included approach that is probably +not viable for an official SDK. In yet others, we simply think there is room for +API refinement, and we should take this opportunity to reconsider. Therefore, +we wrote this SDK design from the perspective of a new implementation. +Nevertheless, much of the API discussed here originated from or was inspired +by mcp-go and other unofficial SDKs, and if the consensus of this discussion is +close enough to mcp-go or any other unofficial SDK, we can start from a fork. + +Although our API is not compatible with mcp-go's, translating between them should be +straightforward in most cases. +(Later, we will provide a detailed translation guide.) + +### Packages -The most popular MCP package for Go is [mcp-go](https://pkg.go.dev/github.com/ -mark3labs/mcp-go). 
While we admire the thoughfulness of its design and the high -quality of its implementation, we made different choices. Although the APIs are -not compatible, translating between them is straightforward. (Later, we will -provide a detailed translation guide.) +As we mentioned above, we decided to put most of the API into a single package. +Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, +`server` and `transport` packages, but is smaller than the `mcp` package alone. -## Packages +### Typed tool inputs -As we mentioned above, we decided to put most of the API into a single package. -The exceptions are the JSON-RPC layer, the JSON Schema implementation, and the -parts of the MCP protocol that users don't need. The resulting `mcp` includes -all the functionality of mcp-go's `mcp`, `client`, `server` and `transport` -packages, but is smaller than the `mcp` package alone. +We provide a way to supply a struct as the input type of a Tool. +For example, a tool input with a required "name" parameter and an optional "size" parameter +could be be described by: +``` +type input struct { + Name string `json:"name"` + Size int `json:"size,omitempty"` +} +``` + +The tool handler receives a value of this struct instead of a `map[string]any`, +so it doesn't need to parse its input parameters. Also, we infer the input schema +from the struct, avoiding the need to specify the name, type and required status of +parameters. + +### Schema validation + +We provide a full JSON Schema implementation for validating tool input schemas against +incoming arguments. The `jsonschema.Schema` type provides exported features for all +keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct +any schema they want, so there is no need to provide options for all of them. +When combined with schema inference from input structs, +we found that we needed only three options to cover the common cases, +instead of mcp-go's 23. 
For example, we provide `Enum`, which occurs 125 times in open source +code, but not MinItems, MinLength or MinProperties, which each occur only once (and in an SDK +that wraps mcp-go). + +Moreover, our options can be used to build nested schemas, while +mcp-go's work only at top level. That limitation is visible in +[this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), +which must resort to untyped maps to express a nested schema: +``` +mcp.WithArray("items", + mcp.Description("Checklist items of the task"), + mcp.Items(map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + "description": "Unique identifier of the checklist item", + }, + "status": map[string]any{ + "type": "number", + "description": "Status of the checklist item (0: normal, 1: completed)", + "enum": []float64{0, 1}, + }, + ... +``` + +### JSON-RPC implementation -## Hooks +The Go team has a battle-tested JSON-RPC implementation that we use for gopls, our +Go LSP server. We are using the new version of this library as part of our MCP SDK. +It handles all JSON-RPC 2.0 features, including cancellation. + +### Hooks Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. -As described above, these can be replaced by middleware. We -don't define any middleware types at present, but will do so if there is demand. -(We're minimalists, not savages.) +These are rarely used. The most common is `OnError`, which occurs fewer than ten +times in open-source code. + +All of the hooks run before or after the server processes a message, +so instead we provide a single way to intercept this message handling, using +two exported names instead of 72: +``` +// A Handler handles an MCP message call. 
+type Handler func(ctx context.Context, c *ServerConnection, method string, params any) (response any, err error) + + +// AddMiddleware calls each middleware function from right to left on the previous result, beginning +// with the server's current handler, and installs the result as the new handler. +func (*Server) AddMiddleware(middleware ...func(Handler) Handler) +``` + +As an example, this code adds server-side logging: +``` + +func withLogging(h mcp.Handler) mcp.Handler { + return func(ctx context.Context, c *mcp.ServerConnection, method string, params any) (res any, err error) { + log.Printf("request: %s %v", method, params) + defer func() { log.Printf("response: %v, %v", res, err) }() + return h(ctx, c, method, params) + } +} + +server.AddMiddleware(withLogging) +``` + +### Options + +In Go, the two most common ways to provide options to a function are option structs (for example, +https://pkg.go.dev/net/http#PushOptions) and +variadic option functions. mcp-go uses option functions exclusively. For example, +the `server.NewMCPServer` function has ten associated functions to provide options. +Our API uses both, depending on the context. We use function options for +constructing tools, where they are most convenient. In most other places, we +prefer structs because they have a smaller API footprint and are less verbose. -## Servers +### Servers In mcp-go, server authors create an `MCPServer`, populate it with tools, resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. These @@ -461,6 +566,6 @@ with `RegisterSession` and `UnregisterSession`. We find the similarity in names among the three server types to be confusing, and we could not discover any uses of the session methods in the open-source -ecosystem. In our design is similar, server authors create a `Server`, and then -connect it to a `Transport` or SSE handler. We manage multiple web clients for a -single server using session IDs internally, but do not expose them.
In our design, server authors create a `Server`, and then +connect it to a `Transport`. An `SSEHandler` manages sessions for +incoming SSE connections, but does not expose them. From c89ad19f28595e890358cb22e171f0c964a66b9a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 8 May 2025 16:12:37 -0400 Subject: [PATCH 025/196] internal/mcp: JSON Schema design Describe how we use JSON Schema. Change-Id: I488a09854333ba7280401e3bed5f50f83f481cd2 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671215 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 58 +++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index f5d63342391..b8e27280eff 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -376,12 +376,46 @@ Server.RemoveTools #### JSON Schema - +A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), +provides a way to validate the tool's input. + +We chose a hybrid a approach to specifying the schema, combining reflection +and variadic options. We found that this makes the common cases easy (sometimes +free!) to express and keeps the API small. The most recent JSON Schema +spec defines over 40 keywords. Providing them all as options would bloat +the API despite the fact that most would be very rarely used. Our approach +also guarantees that the input schema is compatible with tool parameters, by +construction. + +`NewTool` determines the input schema for a Tool from the struct used +in the handler. Each struct field that would be marshaled by `encoding/json.Marshal` +becomes a property of the schema. The property is required unless +the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). 
+For example, given this struct: +``` +struct { + Name string `json:"name"` + Count int `json:"count,omitempty"` + Choices []string + Password []byte `json:"-"` +} +``` +"name" and "Choices" are required, while "count" is optional. + +The struct provides the names, types and required status of the properties. +Other JSON Schema keywords can be specified by passing options to `NewTool`: +``` +NewTool(name, description, handler, + Property("count", Description("size of the inventory"))) +``` + +For less common keywords, use the `Schema` option: +``` +NewTool(name, description, handler, + Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true}))) +``` + +Schemas are validated on the server before the tool handler is called. ### Prompts @@ -458,16 +492,8 @@ Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, ### Typed tool inputs -We provide a way to supply a struct as the input type of a Tool. -For example, a tool input with a required "name" parameter and an optional "size" parameter -could be be described by: -``` -type input struct { - Name string `json:"name"` - Size int `json:"size,omitempty"` -} -``` - +We provide a way to supply a struct as the input type of a Tool, as described +in [JSON Schema](#JSON_Schema), above. The tool handler receives a value of this struct instead of a `map[string]any`, so it doesn't need to parse its input parameters. Also, we infer the input schema from the struct, avoiding the need to specify the name, type and required status of From d60d930447d9a00d4cc43a5882dede21939690e7 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 8 May 2025 17:25:21 -0400 Subject: [PATCH 026/196] internal/mcp: tool section of design Describe our tool design. This is short because most of the interesting part is in the JSON Schema section. 
Change-Id: I017fca3055aef881285280a379e9f235bf95e75d Reviewed-on: https://go-review.googlesource.com/c/tools/+/671355 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index b8e27280eff..d7a205deead 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -366,15 +366,29 @@ Client.RemoveRoots ### Tools - +Add tools to a server with `AddTools`: +``` +server.AddTools( + mcp.NewTool("add", "add numbers", addHandler), + mcp.NewTool("subtract", "subtract numbers", subHandler)) +``` +Remove them with `RemoveTools`: +``` +server.RemoveTools("add", "subtract") +``` + +We provide a convenient and type-safe way to construct a Tool: + +``` +// NewTool creates a Tool using reflection on the given handler. +func NewTool[TReq any](name, description string, handler func(context.Context, TReq) ([]Content, error), opts ...ToolOption) *Tool +``` -#### JSON Schema +The `TReq` type is typically a struct, and we use reflection on the struct to +determine the names and types of the tool's input. `ToolOption`s allow further +customization of a Tool's input schema. +Since all the fields of the Tool struct are exported, a Tool can also be created +directly with assignment or a struct literal. A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), provides a way to validate the tool's input. @@ -406,13 +420,13 @@ The struct provides the names, types and required status of the properties.
Other JSON Schema keywords can be specified by passing options to `NewTool`: ``` NewTool(name, description, handler, - Property("count", Description("size of the inventory"))) + Input(Property("count", Description("size of the inventory")))) ``` For less common keywords, use the `Schema` option: ``` NewTool(name, description, handler, - Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true}))) + Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) ``` Schemas are validated on the server before the tool handler is called. From 4160b77507a8173cbd2dd5a8172d0a2502292e75 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 08:15:27 -0400 Subject: [PATCH 027/196] internal/mcp/design.md: client features Change-Id: Iaed3461bcb35ca65c3118cafa0a5bbddfbd41ecc Reviewed-on: https://go-review.googlesource.com/c/tools/+/671375 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 46 ++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index d7a205deead..cd4323441d5 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -352,15 +352,49 @@ parameterizing automatic keepalive. ### Roots - +Clients support the MCP Roots feature out of the box, including roots-changed notifications. +Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots`: + +``` +// AddRoots adds the roots to the client's list of roots. +// If the list changes, the client notifies the server. +// If a root does not begin with a valid URI schema such as "https://" or "file://", +// it is intepreted as a directory path on the local filesystem. +func (*Client) AddRoots(roots ...string) + +// RemoveRoots removes the given roots from the client's list, and notifies +// the server if the list has changed. +// It is not an error to remove a nonexistent root. 
+func (*Client) RemoveRoots(roots ...string) +``` + +Servers can call `ListRoots` to get the roots. +If a server installs a `RootsChangedHandler`, it will be called when the client sends a +roots-changed notification, which happens whenever the list of roots changes after a +connection has been established. +``` +func (*Server) ListRoots(context.Context, *ListRootsParams) (*ListRootsResult, error) + +type ServerOptions struct { + ... + // If non-nil, called when a client sends a roots-changed notification. + RootsChangedHandler func(context.Context, *ServerConnection, *RootsChangedParams) +} +``` ### Sampling - +Clients that support sampling are created with a `CreateMessageHandler` option for handling server +calls. +To perform sampling, a server calls `CreateMessage`. +``` +type ClientOptions struct { + ... + CreateMessageHandler func(context.Context, *CreateMessageParams) (*CreateMessageResult, error) +} + +func (*Server) CreateMessage(context.Context, *CreateMessageParams) (*CreateMessageResult, error) +``` ## Server Features From 8ee3f58c755bff021add0c244ae8d62c7bf6356f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 08:54:00 -0400 Subject: [PATCH 028/196] internal/mcp/design.md: completions Design for the completions feature. Change-Id: Ia793091a10adb9c51472a9d3d4acdce8d6932536 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671376 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index cd4323441d5..d73d68aa5d5 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -485,7 +485,10 @@ Server.RemoveResources ### Completion - +Clients call `Complete` to request completions. + +Servers automatically handle these requests based on their collections of +prompts and resources.
### Logging From ef35d724cb82c7da6a71e4dca5c26cbcfce5e0c2 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 8 May 2025 18:52:45 +0000 Subject: [PATCH 029/196] internal/mcp/design: add section on clients and servers After much deliberation, revert back to the symmetrical many-to-one form of the client APIs. Every time I tried to write the asymmetrical form, I couldn't justify it. Furthermore, I think the names 'ClientSession' and 'ServerSession' significantly clarify the meaning. Change-Id: I5c5e1eaf0d5de5ac7c98dab7aad9c0f31c3ff4e1 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671177 Auto-Submit: Robert Findley Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/design/design.md | 104 ++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 6 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index d73d68aa5d5..10e2fb3e41c 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -79,7 +79,7 @@ The MCP is defined in terms of client-server communication over bidirectional JSON-RPC message streams. Specifically, version `2025-03-26` of the spec defines two transports: -- ****: communication with a subprocess over stdin/stdout. +- **stdio**: communication with a subprocess over stdin/stdout. - **streamable http**: communication over a relatively complicated series of text/event-stream GET and HTTP POST requests. @@ -324,10 +324,95 @@ type Resource struct { ### Clients and Servers - +Generally speaking, the SDK is used by creating a `Client` or `Server` +instance, adding features to it, and connecting it to a peer. + +However, the SDK must make a non-obvious choice in these APIs: are clients 1:1 +with their logical connections? What about servers? Both clients and servers +are stateful: users may add or remove roots from clients, and tools, prompts, +and resources from servers. 
Additionally, handlers for these features may +themselves be stateful, for example if a tool handler caches state from earlier +requests in the session. + +We believe that in the common case, both clients and servers are stateless, and +it is therefore more useful to allow multiple connections from a client, and to +a server. This is similar to the `net/http` package, in which an `http.Client` +and `http.Server` each may handle multiple unrelated connections. When users +add features to a client or server, all connected peers are notified of the +change in feature-set. + +Following the terminology of the +[spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), +we call the logical connection between a client and server a "session". There +must necessarily be a `ClientSession` and a `ServerSession`, corresponding to +the APIs available from the client and server perspective, respectively. + +``` +Client Server + ⇅ (jsonrpc2) ⇅ +ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession +``` + +Sessions are created from either `Client` or `Server` using the `Connect` +method. + +```go +type Client struct { /* ... */ } +func NewClient(name, version string, opts *ClientOptions) *Client +func (c *Client) Connect(context.Context, Transport) (*ClientSession, error) +// Methods for adding/removing client features are described below. + +type ClientSession struct { /* ... */ } +func (*ClientSession) Close() error +func (*ClientSession) Wait() error +// Methods for calling through the ClientSession are described below. + +type Server struct { /* ... */ } +func NewServer(name, version string, opts *ServerOptions) *Server +func (s *Server) Connect(context.Context, Transport) (*ServerSession, error) +// Methods for adding/removing server features are described below. + +type ServerSession struct { /* ...
*/ } +func (*ServerSession) Close() error +func (*ServerSession) Wait() error +// Methods for calling through the ServerSession are described below. +``` + +Here's an example of these API from the client side: + +```go +client := mcp.NewClient("mcp-client", "v1.0.0", nil) +// Connect to a server over stdin/stdout +transport := mcp.NewCommandTransport(exec.Command("myserver")) +session, err := client.Connect(ctx, transport) +if err != nil { + log.Fatal(err) +} +// Call a tool on the server. +content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}) +... +return session.Close() +``` + +And here's an example from the server side: + +```go +// Create a server with a single tool. +server := mcp.NewServer("greeter", "v1.0.0", nil) +server.AddTool(mcp.NewTool("greet", "say hi", SayHi)) +// Run the server over stdin/stdout, until the client disconnects. +transport := mcp.NewStdIOTransport() +session, err := server.Connect(ctx, transport) +... +return session.Wait() +``` + +For convenience, we provide `Server.Run` to handle the common case of running a +session until the client disconnects: + +```go +func (*Server) Run(context.Context, Transport) +``` ### Errors @@ -429,7 +514,7 @@ provides a way to validate the tool's input. We chose a hybrid a approach to specifying the schema, combining reflection and variadic options. We found that this makes the common cases easy (sometimes -free!) to express and keeps the API small. The most recent JSON Schema +free!) to express and keeps the API small. The most recent JSON Schema spec defines over 40 keywords. Providing them all as options would bloat the API despite the fact that most would be very rarely used. Our approach also guarantees that the input schema is compatible with tool parameters, by @@ -440,6 +525,7 @@ in the handler. Each struct field that would be marshaled by `encoding/json.Mars becomes a property of the schema. 
The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: + ``` struct { Name string `json:"name"` @@ -448,16 +534,19 @@ struct { Password []byte `json:"-"` } ``` + "name" and "Choices" are required, while "count" is optional. The struct provides the names, types and required status of the properties. Other JSON Schema keywords can be specified by passing options to `NewTool`: + ``` NewTool(name, description, handler, Input(Property("count", Description("size of the inventory")))) ``` For less common keywords, use the `Schema` option: + ``` NewTool(name, description, handler, Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) @@ -566,6 +655,7 @@ Moreover, our options can be used to build nested schemas, while mcp-go's work only at top level. That limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema: + ``` mcp.WithArray("items", mcp.Description("Checklist items of the task"), @@ -600,6 +690,7 @@ times in open-source code. All of the hooks run before or after the server processes a message, so instead we provide a single way to intercept this message handling, using two exported names instead of 72: + ``` // A Handler handles an MCP message call. 
type Handler func(ctx context.Context, c *ServerConnection, method string, params any) (response any, err error) @@ -611,6 +702,7 @@ func (*Server) AddMiddleware(middleware ...func(Handler) Handler)) ``` As an example, this code adds server-side logging: + ``` func withLogging(h mcp.Handler) mcp.Handler { From b61ab3318cce58492e8fa88a00521b2ba4870b03 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 9 May 2025 14:47:48 +0000 Subject: [PATCH 030/196] internal/mcp/design: add a section on cancellation Change-Id: If18a7d8efa37a6f34ae42f4d1896464800ebcdc7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671435 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/design/design.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 10e2fb3e41c..c4b663762ef 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -420,7 +420,20 @@ func (*Server) Run(context.Context, Transport) ### Cancellation - +Cancellation is implemented transparently using context cancellation. The user +can cancel an operation by cancelling the associated context: + +```go +ctx, cancel := context.WithCancel(ctx) +go session.CallTool(ctx, "slow", map[string]any{}) +cancel() +``` + +When this client call is cancelled, a `"notifications/cancelled"` notification +is sent to the server. However, the client call returns immediately with +`ctx.Err()`: it does not wait for the result from the server. + +The server observes a client cancellation as a cancelled context.
### Progress handling From b4891595167031bf6f9eef231c2fac08c280993f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 10:28:43 -0400 Subject: [PATCH 031/196] internal/mcp: progress notifications Change-Id: I3dad9326a88c14d08d62395f603862f71f887a75 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671475 TryBot-Bypass: Jonathan Amsterdam Reviewed-by: Robert Findley --- internal/mcp/design/design.md | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index c4b663762ef..f579b238646 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -437,7 +437,24 @@ The server observes a client cancellation as cancelled context. ### Progress handling - +A caller can request progress notifications by setting the `ProgressToken` field on any request. +```go +type ProgressToken any + +type XXXParams struct { // where XXX is each type of call + ... + ProgressToken ProgressToken +} +``` +Handlers can notify their peer about progress by calling the `NotifyProgress` +method. The notification is only sent if the peer requested it. +```go +func (*ClientSession) NotifyProgress(context.Context, *ProgressNotification) +func (*ServerSession) NotifyProgress(context.Context, *ProgressNotification) +``` +We don't support progress notifications for `Client.ListRoots`, because we expect +that operation to be instantaneous relative to network latency. + ### Ping / Keepalive @@ -488,7 +505,7 @@ To perform sampling, a server calls `CreateMessage`. ``` type ClientOptions struct { ... 
- CreateMessageHandler func(context.Context, *CreateMessageParams) (*CreateMessageResult, error) + CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) } func (*Server) CreateMessage(context.Context, *CreateMessageParams) (*CreateMessageResult, error) From 721ad8de9116f26a9bf82488b5317ec675166770 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 11:07:08 -0400 Subject: [PATCH 032/196] internal/mcp/design.md: prompts and resources Change-Id: I5663e14846e3a37bea6947a12fbe57117b07174b Reviewed-on: https://go-review.googlesource.com/c/tools/+/671497 TryBot-Bypass: Jonathan Amsterdam Reviewed-by: Robert Findley --- internal/mcp/design/design.md | 77 ++++++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index f579b238646..e37604a7a31 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -521,7 +521,7 @@ server.AddTools( mcp.NewTool("add", "add numbers", addHandler), mcp.NewTools("subtract, subtract numbers", subHandler)) ``` -Remove them with `RemoveTools`: +Remove them by name with `RemoveTools`: ``` server.RemoveTools("add", "subtract") ``` @@ -586,21 +586,70 @@ Schemas are validated on the server before the tool handler is called. ### Prompts - +Use `NewPrompt` to create a prompt. +As with tools, prompt argument schemas can be inferred from a struct, or obtained +from options. +```go +func NewPrompt[TReq any](name, description string, + handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), + opts ...PromptOption) *ServerPrompt +``` +Use `AddPrompts` to add prompts to the server, and `RemovePrompts` +to remove them by name. 
-### Resources +```go +type codeReviewArgs struct { + Code string `json:"code"` +} - +func codeReviewHandler(context.Context, *ServerSession, codeReviewArgs) {...} + +server.AddPrompts( + NewPrompt("code_review", "review code", codeReviewHandler, + Argument("code", Description("the code to review")))) + +server.RemovePrompts("code_review") +``` + +Clients can call ListPrompts to list the available prompts and GetPrompt to get one. +```go +func (*ClientSession) ListPrompts(context.Context, *ListPromptParams) (*ListPromptsResult, error) +func (*ClientSession) GetPrompt(context.Context, *GetPromptParams) (*GetPromptResult, error) +``` + +### Resources and resource templates + +Servers have Add and Remove methods for resources and resource templates: +```go +func (*Server) AddResources(resources ...*Resource) +func (*Server) RemoveResources(names ...string) +func (*Server) AddResourceTemplates(templates...*ResourceTemplate) +func (*Server) RemoveResourceTemplates(names ...string) +``` +Clients call ListResources to list the available resources, ReadResource to read +one of them, and ListResourceTemplates to list the templates: +```go +func (*ClientSession) ListResources(context.Context, *ListResourcesParams) (*ListResourcesResult, error) +func (*ClientSession) ReadResource(context.Context, *ReadResourceParams) (*ReadResourceResult, error) +func (*ClientSession) ListResourceTemplates(context.Context, *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) +``` + + + +### ListChanged notifications + +When a list of tools, prompts or resources changes as the result of an AddXXX +or RemoveXXX call, the server informs all its connected clients by sending the +corresponding type of notification. +A client will receive these notifications if it was created with the corresponding option: +```go +type ClientOptions struct { + ... 
+ ToolListChangedHandler func(context.Context, *ClientConnection, *ToolListChangedParams) + PromptListChangedHandler func(context.Context, *ClientConnection, *PromptListChangedParams) + ResourceListChangedHandler func(context.Context, *ClientConnection, *ResourceListChangedParams) +} +``` ### Completion From d11c94a2777096479cbf038cac064226bc2ea9f6 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 9 May 2025 15:24:55 +0000 Subject: [PATCH 033/196] internal/design: minor cleanup; add 'errors' and 'ping' sections Perform some minor cleanup: - Resolve semantic conflicts updating ServerConnection->ServerSession - Use 'go' code blocks throughout Also, add brief sections discussing error handling and 'ping/keepalive'. Change-Id: I9ee888a7f0506ca380c2d4d4f9965a4c399aac19 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671495 Reviewed-by: Jonathan Amsterdam Commit-Queue: Robert Findley TryBot-Bypass: Robert Findley --- internal/mcp/design/design.md | 102 ++++++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 23 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index e37604a7a31..3f1e5f732f6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -362,7 +362,10 @@ func NewClient(name, version string, opts *ClientOptions) *Client func (c *Client) Connect(context.Context, Transport) (*ClientSession, error) // Methods for adding/removing client features are described below. +type ClientOptions struct { /* ... */ } // described below + type ClientSession struct { /* ... */ } +func (*ClientSession) Client() *Client func (*ClientSession) Close() error func (*ClientSession) Wait() error // Methods for calling through the ClientSession are described below. @@ -372,13 +375,16 @@ func NewServer(name, version string, opts *ServerOptions) *Server func (s *Server) Connect(context.Context, Transport) (*ServerSession, error) // Methods for adding/removing server features are described below. 
+type ServerOptions struct { /* ... */ } // described below + type ServerSession struct { /* ... */ } +func (*ServerSession) Server() *Server func (*ServerSession) Close() error func (*ServerSession) Wait() error // Methods for calling through the ServerSession are described below. ``` -Here's an example of these API from the client side: +Here's an example of these APIs from the client side: ```go client := mcp.NewClient("mcp-client", "v1.0.0", nil) @@ -416,7 +422,24 @@ func (*Server) Run(context.Context, Transport) ### Errors - +With the exception of tool handler errors, protocol errors are handled +transparently as Go errors: errors in server-side feature handlers are +propagated as errors from calls from the `ClientSession`, and vice-versa. + +Protocol errors wrap a `JSONRPC2Error` type which exposes its underlying error +code. + +```go +type JSONRPC2Error struct { + Code int64 `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data,omitempty"` +} +``` + +As described by the +[spec](https://modelcontextprotocol.io/specification/2025-03-26/server/tools#error-handling), +tool execution errors are reported in tool results. ### Cancellation @@ -438,6 +461,7 @@ The server observes a client cancellation as cancelled context. ### Progress handling A caller can request progress notifications by setting the `ProgressToken` field on any request. + ```go type ProgressToken any @@ -446,22 +470,44 @@ type XXXParams struct { // where XXX is each type of call ProgressToken ProgressToken } ``` + Handlers can notify their peer about progress by calling the `NotifyProgress` method. The notification is only sent if the peer requested it. + ```go func (*ClientSession) NotifyProgress(context.Context, *ProgressNotification) func (*ServerSession) NotifyProgress(context.Context, *ProgressNotification) ``` + We don't support progress notifications for `Client.ListRoots`, because we expect that operation to be instantaneous relative to network latency. 
+### Ping / KeepAlive -### Ping / Keepalive +Both `ClientSession` and `ServerSession` expose a `Ping` method to call "ping" +on their peer. - +```go +func (c *ClientSession) Ping(ctx context.Context) error +func (c *ServerSession) Ping(ctx context.Context) error +``` + +Additionally, client and server sessions can be configured with automatic +keepalive behavior. If set to a non-zero value, this duration defines an +interval for regular "ping" requests. If the peer fails to respond to pings +originating from the keepalive check, the session is automatically closed. + +```go +type ClientOptions struct { + ... + KeepAlive time.Duration +} + +type ServerOptions struct { + ... + KeepAlive time.Duration +} +``` ## Client Features @@ -470,7 +516,7 @@ parameterizing automatic keepalive. Clients support the MCP Roots feature out of the box, including roots-changed notifications. Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots`: -``` +```go // AddRoots adds the roots to the client's list of roots. // If the list changes, the client notifies the server. // If a root does not begin with a valid URI schema such as "https://" or "file://", @@ -487,13 +533,14 @@ Servers can call `ListRoots` to get the roots. If a server installs a `RootsChangedHandler`, it will be called when the client sends a roots-changed notification, which happens whenever the list of roots changes after a connection has been established. -``` + +```go func (*Server) ListRoots(context.Context, *ListRootsParams) (*ListRootsResult, error) type ServerOptions { ... // If non-nil, called when a client sends a roots-changed notification. - RootsChangedHandler func(context.Context, *ServerConnection, *RootsChangedParams) + RootsChangedHandler func(context.Context, *ServerSession, *RootsChangedParams) } ``` @@ -502,7 +549,8 @@ type ServerOptions { Clients that support sampling are created with a `CreateMessageHandler` option for handling server calls. 
To perform sampling, a server calls `CreateMessage`. -``` + +```go type ClientOptions struct { ... CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) @@ -516,19 +564,22 @@ func (*Server) CreateMessage(context.Context, *CreateMessageParams) (*CreateMess ### Tools Add tools to a server with `AddTools`: -``` + +```go server.AddTools( mcp.NewTool("add", "add numbers", addHandler), mcp.NewTools("subtract, subtract numbers", subHandler)) ``` + Remove them by name with `RemoveTools`: + ``` server.RemoveTools("add", "subtract") ``` We provide a convenient and type-safe way to construct a Tool: -``` +```go // NewTool is a creates a Tool using reflection on the given handler. func NewTool[TReq any](name, description string, handler func(context.Context, TReq) ([]Content, error), opts …ToolOption) *Tool ``` @@ -556,7 +607,7 @@ becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: -``` +```go struct { Name string `json:"name"` Count int `json:"count,omitempty"` @@ -570,7 +621,7 @@ struct { The struct provides the names, types and required status of the properties. Other JSON Schema keywords can be specified by passing options to `NewTool`: -``` +```go NewTool(name, description, handler, Input(Property("count", Description("size of the inventory")))) ``` @@ -589,11 +640,13 @@ Schemas are validated on the server before the tool handler is called. Use `NewPrompt` to create a prompt. As with tools, prompt argument schemas can be inferred from a struct, or obtained from options. + ```go func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt ``` + Use `AddPrompts` to add prompts to the server, and `RemovePrompts` to remove them by name. 
@@ -612,6 +665,7 @@ server.RemovePrompts("code_review") ``` Clients can call ListPrompts to list the available prompts and GetPrompt to get one. + ```go func (*ClientSession) ListPrompts(context.Context, *ListPromptParams) (*ListPromptsResult, error) func (*ClientSession) GetPrompt(context.Context, *GetPromptParams) (*GetPromptResult, error) @@ -620,14 +674,17 @@ func (*ClientSession) GetPrompt(context.Context, *GetPromptParams) (*GetPromptRe ### Resources and resource templates Servers have Add and Remove methods for resources and resource templates: + ```go func (*Server) AddResources(resources ...*Resource) func (*Server) RemoveResources(names ...string) func (*Server) AddResourceTemplates(templates...*ResourceTemplate) func (*Server) RemoveResourceTemplates(names ...string) ``` + Clients call ListResources to list the available resources, ReadResource to read one of them, and ListResourceTemplates to list the templates: + ```go func (*ClientSession) ListResources(context.Context, *ListResourcesParams) (*ListResourcesResult, error) func (*ClientSession) ReadResource(context.Context, *ReadResourceParams) (*ReadResourceResult, error) @@ -642,6 +699,7 @@ When a list of tools, prompts or resources changes as the result of an AddXXX or RemoveXXX call, the server informs all its connected clients by sending the corresponding type of notification. A client will receive these notifications if it was created with the corresponding option: + ```go type ClientOptions struct { ... @@ -735,7 +793,7 @@ mcp-go's work only at top level. 
That limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema: -``` +```go mcp.WithArray("items", mcp.Description("Checklist items of the task"), mcp.Items(map[string]any{ @@ -770,10 +828,9 @@ All of the hooks run before or after the server processes a message, so instead we provide a single way to intercept this message handling, using two exported names instead of 72: -``` +```go // A Handler handles an MCP message call. -type Handler func(ctx context.Context, c *ServerConnection, method string, params any) (response any, err error) - +type Handler func(ctx context.Context, s *ServerSession, method string, params any) (response any, err error) // AddMiddleware calls each middleware function from right to left on the previous result, beginning // with the server's current handler, and installs the result as the new handler. @@ -782,13 +839,12 @@ func (*Server) AddMiddleware(middleware ...func(Handler) Handler)) As an example, this code adds server-side logging: -``` - +```go func withLogging(h mcp.Handler) mcp.Handler { - return func(ctx context.Context, c *mcp.ServerConnection, method string, params any) (res any, err error) { + return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() - return h(ctx, c , method, params) + return h(ctx, s , method, params) } } From b303c1f206acc362456abd1f3b34e43b18b8a9b4 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 9 May 2025 16:08:58 +0000 Subject: [PATCH 034/196] internal/mcp/design: remove the protocol package; update tools Per discussion, remove the separate protocol package in favor of a single mcp package. Update the tools section accordingly, and add more discussion about the decision of the NewTool API. 
Change-Id: Ifb259555ca9d493ce63bc4f71e5f2fa295bdf09c Reviewed-on: https://go-review.googlesource.com/c/tools/+/671496 TryBot-Bypass: Robert Findley Commit-Queue: Robert Findley Reviewed-by: Jonathan Amsterdam --- internal/mcp/design/design.md | 130 ++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 45 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 3f1e5f732f6..1bb65e533c3 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -60,14 +60,9 @@ Therefore, this is the package layout. `module.path` is a placeholder for the final module path of the mcp module - `module.path/mcp`: the bulk of the user facing API -- `module.path/mcp/protocol`: generated types for the MCP spec. - `module.path/jsonschema`: a jsonschema implementation, with validation - `module.path/internal/jsonrpc2`: a fork of x/tools/internal/jsonrpc2_v2 -For now, this layout assumes we want to separate the 'protocol' types from the -'mcp' package, since they won't be needed by most users. It is unclear whether -this is worthwhile. - The JSON-RPC implementation is hidden, to avoid tight coupling. As described in the next section, the only aspects of JSON-RPC that need to be exposed in the SDK are the message types, for the purposes of defining custom transports. We @@ -118,7 +113,7 @@ interface. Other SDKs define higher-level transports, with, for example, methods to send a notification or make a call. Those are jsonrpc2 operations on top of the logical stream, and the lower-level interface is easier to implement in most -cases, which means it is easier to implement custom transports or middleware. +cases, which means it is easier to implement custom transports. For our prototype, we've used an internal `jsonrpc2` package based on the Go language server `gopls`, which we propose to fork for the MCP SDK. 
It already @@ -268,13 +263,15 @@ func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ### Protocol types -As described in the section on package layout above, the `protocol` package -will contain definitions of types referenced by the MCP spec that are needed -for the SDK. JSON-RPC message types are elided, since they are handled by the -`jsonrpc2` package and should not be observed by the user. The user interacts -only with the params/result types relevant to MCP operations. +Types needed for the protocol are generated from the +[JSON schema of the MCP spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). + +These types will be included in the `mcp` package, but will be unexported +unless they are needed for the user-facing API. Notably, JSON-RPC message types +are elided, since they are handled by the `jsonrpc2` package and should not be +observed by the user. -For user-provided data, use `json.RawMessage`, so that +For user-provided data, we use `json.RawMessage`, so that marshalling/unmarshalling can be delegated to the business logic of the client or server. @@ -282,8 +279,6 @@ For union types, which can't be represented in Go (specifically `Content` and `Resource`), we prefer distinguished unions: struct types with fields corresponding to the union of all properties for union elements. -These types will be auto-generated from the [JSON schema of the MCP -spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). For brevity, only a few examples are shown here: ```go @@ -529,9 +524,9 @@ func (*Client) AddRoots(roots ...string) func (*Client) RemoveRoots(roots ...string) ``` -Servers can call `ListRoots` to get the roots. 
-If a server installs a `RootsChangedHandler`, it will be called when the client sends a -roots-changed notification, which happens whenever the list of roots changes after a +Servers can call `ListRoots` to get the roots. If a server installs a +`RootsChangedHandler`, it will be called when the client sends a roots-changed +notification, which happens whenever the list of roots changes after a connection has been established. ```go @@ -546,9 +541,8 @@ type ServerOptions { ### Sampling -Clients that support sampling are created with a `CreateMessageHandler` option for handling server -calls. -To perform sampling, a server calls `CreateMessage`. +Clients that support sampling are created with a `CreateMessageHandler` option +for handling server calls. To perform sampling, a server calls `CreateMessage`. ```go type ClientOptions struct { @@ -563,6 +557,25 @@ func (*Server) CreateMessage(context.Context, *CreateMessageParams) (*CreateMess ### Tools +A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` +is a tool bound to a tool handler. + +```go +type Tool struct { + Annotations *ToolAnnotations `json:"annotations,omitempty"` + Description string `json:"description,omitempty"` + InputSchema *jsonschema.Schema `json:"inputSchema"` + Name string `json:"name"` +} + +type ToolHandler func(context.Context, *ServerSession, map[string]json.RawMessage) (*CallToolResult, error) + +type ServerTool struct { + Tool Tool + Handler ToolHandler +} +``` + Add tools to a server with `AddTools`: ```go @@ -573,33 +586,43 @@ server.AddTools( Remove them by name with `RemoveTools`: -``` - server.RemoveTools("add", "subtract") +```go +server.RemoveTools("add", "subtract") ``` -We provide a convenient and type-safe way to construct a Tool: +A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), +provides a way to validate the tool's input. 
One of the challenges in defining +tools is the need to associate them with a Go function, yet support the +arbitrary complexity of JSON Schema. To achieve this, we have seen two primary +approaches: + +1. Use reflection to generate the tool's input schema from a Go type (à la + `metoro-io/mcp-golang`) +2. Explicitly build the input schema (à la `mark3labs/mcp-go`). + +Both of these have their advantages and disadvantages. Reflection is nice, +because it allows you to bind directly to a Go API, and means that the JSON +schema of your API is compatible with your Go types by construction. It also +means that concerns like parsing and validation can be handled automatically. +However, it can become cumbersome to express the full breadth of JSON schema +using Go types or struct tags, and sometimes you want to express things that +aren’t naturally modeled by Go types, like unions. Explicit schemas are simple +and readable, and give the caller full control over their tool definition, but +involve significant boilerplate. + +We believe that a hybrid model works well, where the _initial_ schema is +derived using reflection, but any customization on top of that schema is +applied using variadic options. We achieve this using a `NewTool` helper, which +generates the schema from the input type, and wraps the handler to provide +parsing and validation. The schema (and potentially other features) can be +customized using ToolOptions. ```go // NewTool is a creates a Tool using reflection on the given handler. -func NewTool[TReq any](name, description string, handler func(context.Context, TReq) ([]Content, error), opts …ToolOption) *Tool -``` - -The `TReq` type is typically a struct, and we use reflection on the struct to -determine the names and types of the tool's input. `ToolOption`s allow further -customization of a Tool's input schema. -Since all the fields of the Tool struct are exported, a Tool can also be created -directly with assignment or a struct literal.
- -A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), -provides a way to validate the tool's input. +func NewTool[TInput any](name, description string, handler func(context.Context, *ServerSession, TInput) ([]Content, error), opts ...ToolOption) *ServerTool -We chose a hybrid a approach to specifying the schema, combining reflection -and variadic options. We found that this makes the common cases easy (sometimes -free!) to express and keeps the API small. The most recent JSON Schema -spec defines over 40 keywords. Providing them all as options would bloat -the API despite the fact that most would be very rarely used. Our approach -also guarantees that the input schema is compatible with tool parameters, by -construction. +type ToolOption interface { /* ... */ } +``` `NewTool` determines the input schema for a Tool from the struct used in the handler. Each struct field that would be marshaled by `encoding/json.Marshal` @@ -618,23 +641,40 @@ struct { "name" and "Choices" are required, while "count" is optional. -The struct provides the names, types and required status of the properties. -Other JSON Schema keywords can be specified by passing options to `NewTool`: +As of writing, the only `ToolOption` is `Input`, which allows customizing the +input schema of the tool using schema options. These schema options are +recursive, in the sense that they may also be applied to properties. + +```go +func Input(...SchemaOption) ToolOption + +func Property(name string, opts ...SchemaOption) SchemaOption +func Description(desc string) SchemaOption +// etc. +``` + +For example: ```go NewTool(name, description, handler, Input(Property("count", Description("size of the inventory")))) ``` -For less common keywords, use the `Schema` option: +The most recent JSON Schema spec defines over 40 keywords. Providing them all +as options would bloat the API despite the fact that most would be very rarely +used.
For less common keywords, use the `Schema` option to set the schema +explicitly: -``` +```go NewTool(name, description, handler, Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) ``` Schemas are validated on the server before the tool handler is called. +Since all the fields of the Tool struct are exported, a Tool can also be created +directly with assignment or a struct literal. + ### Prompts Use `NewPrompt` to create a prompt. From 2cf2b2aed5c0bfa9984ea08ef9e0aa9222429fcb Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 9 May 2025 16:46:54 +0000 Subject: [PATCH 035/196] internal/mcp/design: weave discussion of differences with mcp-go Put discussion of mcp-go up front, and address differences in the individual sections. I think this leads to easier reading for the user who is concerned about our divergence in design. Change-Id: Ifd21f6b9eb570ad4a325b8a8800b7868dac12205 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671556 TryBot-Bypass: Robert Findley Reviewed-by: Jonathan Amsterdam Commit-Queue: Robert Findley --- internal/mcp/design/design.md | 344 ++++++++++++++++++---------------- 1 file changed, 180 insertions(+), 164 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 1bb65e533c3..de3a07bde75 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -2,11 +2,37 @@ This file discusses the design of a Go SDK for the [model context protocol](https://modelcontextprotocol.io/specification/2025-03-26). It is -intended to seed a GitHub discussion about the official Go MCP SDK, and so -approaches each design aspect from first principles. Of course, there is -significant prior art in various unofficial SDKs, along with other practical -constraints. Nevertheless, if we can first agree on the most natural way to -model the MCP spec, we can then discuss the shortest path to get there. +intended to seed a GitHub discussion about the official Go MCP SDK. 
+ +The golang.org/x/tools/internal/mcp package contains a prototype that we built +to explore the MCP design space. Many of the ideas there are present in this +document. However, we have diverged and expanded on the APIs of that prototype, +and this document should be considered canonical. + +## Similarities and differences with mark3labs/mcp-go + +The most popular unofficial MCP SDK for Go is +[mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go). As of this +writing, it is imported by over 400 packages that span over 200 modules. + +We admire mcp-go, and seriously considered simply adopting it as a starting +point for this SDK. However, as we looked at doing so, we realized that a +significant amount of its API would probably need to change. In some cases, +mcp-go has older APIs that predated newer variations--an obvious opportunity +for cleanup. In others, it took a batteries-included approach that is probably +not viable for an official SDK. In yet others, we simply think there is room for +API refinement, and we should take this opportunity to consider our options. +Therefore, we wrote this document as though it were proposing a new +implementation. Nevertheless, much of the API discussed here originated from or +was inspired by mcp-go and other unofficial SDKs, and if the consensus of this +discussion is close enough to mcp-go or any other unofficial SDK, we can start +from a fork. + +Since mcp-go is so influential and popular, we have noted significant +differences from its API in the sections below. Although the API here is not +compatible with mcp-go, translating between them should be straightforward in +most cases. +(Later, we will provide a detailed translation guide.) # Requirements @@ -35,15 +61,6 @@ the order they are presented by the [official spec](https://modelcontextprotocol For each, we discuss considerations for the Go implementation. In many cases an API is suggested, though in some there many be open questions. 
- - ## Foundations ### Package layout @@ -68,6 +85,9 @@ the next section, the only aspects of JSON-RPC that need to be exposed in the SDK are the message types, for the purposes of defining custom transports. We can expose these types from the `mcp` package via aliases or wrappers. +**Difference from mcp-go**: Our `mcp` package includes all the functionality of +mcp-go's `mcp`, `client`, `server` and `transport` packages. + ### jsonrpc2 and Transports The MCP is defined in terms of client-server communication over bidirectional @@ -165,21 +185,21 @@ Importantly, since they serve many connections, the HTTP handlers must accept a callback to get an MCP server for each new session. ```go -// SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by +// SSEHTTPHandler is an http.Handler that serves SSE-based MCP sessions as defined by // the 2024-11-05 version of the MCP protocol. -type SSEHandler struct { /* unexported fields */ } +type SSEHTTPHandler struct { /* unexported fields */ } -// NewSSEHandler returns a new [SSEHandler] that is ready to serve HTTP. +// NewSSEHTTPHandler returns a new [SSEHTTPHandler] that is ready to serve HTTP. // // The getServer function is used to bind created servers for new sessions. It // is OK for getServer to return the same server multiple times. -func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler +func NewSSEHTTPHandler(getServer func(request *http.Request) *Server) *SSEHTTPHandler -func (*SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) +func (*SSEHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) -// Close prevents the SSEHandler from accepting new sessions, closes active +// Close prevents the SSEHTTPHandler from accepting new sessions, closes active // sessions, and awaits their graceful termination. 
-func (*SSEHandler) Close() error +func (*SSEHTTPHandler) Close() error ``` Notably absent are options to hook into the request handling for the purposes @@ -240,11 +260,11 @@ more complicated implementation. For brevity, we summarize only the differences from the equivalent SSE types: ```go -// The StreamableHandler interface is symmetrical to the SSEHandler. -type StreamableHandler struct { /* unexported fields */ } -func NewStreamableHandler(getServer func(request *http.Request) *Server) *StreamableHandler -func (*StreamableHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) -func (*StreamableHandler) Close() error +// The StreamableHTTPHandler type is symmetrical to the SSEHTTPHandler. +type StreamableHTTPHandler struct { /* unexported fields */ } +func NewStreamableHTTPHandler(getServer func(request *http.Request) *Server) *StreamableHTTPHandler +func (*StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) +func (*StreamableHTTPHandler) Close() error // Unlike the SSE transport, the streamable transport constructor accepts a // session ID, not an endpoint, along with the http response for the request @@ -261,6 +281,32 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport { func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ``` +**Differences from mcp-go**: The Go team has a battle-tested JSON-RPC +implementation that we use for gopls, our Go LSP server. We are using the new +version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 +features, including cancellation. + +The `Transport` interface here is lower-level than that of mcp-go, but serves a +similar purpose. We believe the lower-level interface is easier to implement. + +In mcp-go, server authors create an `MCPServer`, populate it with tools, +resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. These +also use session IDs, which are exposed.
Users can manage their own sessions +with `RegisterSession` and `UnregisterSession`. + +We find the similarity in names among the three server types to be confusing, +and we could not discover any uses of the session methods in the open-source +ecosystem. In our design, server authors create a `Server`, and then +connect it to a `Transport`. An `SSEHTTPHandler` manages sessions for +incoming SSE connections, but does not expose them. HTTP handlers accept a +server constructor, rather than a `Server`, to allow for stateful or "per-session" +servers. + +Individual handlers and transports here have a minimal API, and do not +expose internal details. Customization of things like handlers or session +management is intended to be implemented with middleware and/or compositional +patterns. + ### Protocol types Types needed for the protocol are generated from the @@ -317,6 +363,9 @@ type Resource struct { } ``` +**Differences from mcp-go**: these types are largely similar, but our type +generation flattens types rather than using struct embedding. + ### Clients and Servers Generally speaking, the SDK is used by creating a `Client` or `Server` @@ -415,17 +464,71 @@ session until the client disconnects: func (*Server) Run(context.Context, Transport) ``` +**Differences from mcp-go**: the Server APIs are very similar to mcp-go, +though the association between servers and transports is different. In +mcp-go, a single server is bound to what we would call an `SSEHTTPHandler`, +and reused for all client sessions. As discussed above, the transport +abstraction here is differentiated from HTTP serving, and the `Server.Connect` +method provides a consistent API for binding to an arbitrary transport. Servers +here do not have methods for sending notifications or calls, because they are +logically distinct from the `ServerSession`.
In mcp-go, servers are `n:1`, +but there is no abstraction of a server session: sessions are addressed in +Server APIs through their `sessionID`: `SendNotificationToAllClients`, +`SendNotificationToClient`, `SendNotificationToSpecificClient`. + +The client API here is different, since clients and client sessions are +conceptually distinct. The `ClientSession` is closer to mcp-go's notion of +Client. + +For both clients and servers, mcp-go uses variadic options to customize +behavior, whereas an options struct is used here. We felt that in this case, an +options struct would be more readable, and result in cleaner package +documentation. + +### Middleware + +We provide a mechanism to add MCP-level middleware, which runs after the +request has been parsed, but before any normal handling. + +```go +// A Handler handles an MCP message call. +type Handler func(ctx context.Context, s *ServerSession, method string, params any) (response any, err error) + +// AddMiddleware calls each middleware function from right to left on the previous result, beginning +// with the server's current handler, and installs the result as the new handler. +func (*Server) AddMiddleware(middleware ...func(Handler) Handler) +``` + +As an example, this code adds server-side logging: + +```go +func withLogging(h mcp.Handler) mcp.Handler { + return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { + log.Printf("request: %s %v", method, params) + defer func() { log.Printf("response: %v, %v", res, err) }() + return h(ctx, s, method, params) + } +} + +server.AddMiddleware(withLogging) +``` + +**Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. +Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and +a type for the hook function. These are rarely used. The most common is +`OnError`, which occurs fewer than ten times in open-source code.
+ ### Errors With the exception of tool handler errors, protocol errors are handled transparently as Go errors: errors in server-side feature handlers are propagated as errors from calls from the `ClientSession`, and vice-versa. -Protocol errors wrap a `JSONRPC2Error` type which exposes its underlying error +Protocol errors wrap a `JSONRPCError` type which exposes its underlying error code. ```go -type JSONRPC2Error struct { +type JSONRPCError struct { Code int64 `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data,omitempty"` @@ -436,6 +539,10 @@ As described by the [spec](https://modelcontextprotocol.io/specification/2025-03-26/server/tools#error-handling), tool execution errors are reported in tool results. +**Differences from mcp-go**: the `JSONRPCError` type here does not expose +details that are irrelevant or can be inferred from the caller (ID and Method). +Otherwise, this behavior is similar. + ### Cancellation Cancellation is implemented transparently using context cancellation. The user @@ -504,6 +611,10 @@ type ServerOptions struct { } ``` +**Differences from mcp-go**: in mcp-go the `Ping` method is only provided for +client, not server, and the keepalive option is only provided for SSE servers +(as a variadic option). + ## Client Features ### Roots @@ -550,7 +661,7 @@ type ClientOptions struct { CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) } -func (*Server) CreateMessage(context.Context, *CreateMessageParams) (*CreateMessageResult, error) +func (*ServerSession) CreateMessage(context.Context, *CreateMessageParams) (*CreateMessageResult, error) ``` ## Server Features @@ -675,6 +786,31 @@ Schemas are validated on the server before the tool handler is called. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. 
+**Differences from mcp-go**: using variadic options to configure tools was +significantly inspired by mcp-go. However, the distinction between `ToolOption` +and `SchemaOption` allows for recursive application of schema options, whereas +mcp-go's options apply only at the top level. That limitation is visible in [this +code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), +which must resort to untyped maps to express a nested schema. + +Additionally, the `NewTool` helper provides a means for building a tool from a +Go function using reflection that automatically handles parsing and validation +of inputs. + +We provide a full JSON Schema implementation for validating tool input schemas +against incoming arguments. The `jsonschema.Schema` type provides exported +features for all keywords in the JSON Schema draft2020-12 spec. Tool definers +can use it to construct any schema they want, so there is no need to provide +options for all of them. When combined with schema inference from input +structs, we found that we needed only three options to cover the common cases, +instead of mcp-go's 23. For example, we will provide `Enum`, which occurs 125 +times in open source code, but not MinItems, MinLength or MinProperties, which +each occur only once (and in an SDK that wraps mcp-go). + +For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, +`AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. +(Similarly for Delete/Remove.) + ### Prompts Use `NewPrompt` to create a prompt. @@ -711,6 +847,10 @@ func (*ClientSession) ListPrompts(context.Context, *ListPromptParams) (*ListProm func (*ClientSession) GetPrompt(context.Context, *GetPromptParams) (*GetPromptResult, error) ``` +**Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt +handler to a Go function using reflection to derive its arguments. We provide +`RemovePrompts` to remove prompts from the server.
+ ### Resources and resource templates Servers have Add and Remove methods for resources and resource templates: @@ -749,13 +889,24 @@ type ClientOptions struct { } ``` +**Differences from mcp-go**: mcp-go instead provides a general `OnNotification` +handler. For type-safety, and to hide JSON RPC details, we provide +feature-specific handlers here. + ### Completion Clients call `Complete` to request completions. +```go +func (*ClientSession) Complete(context.Context, *CompleteParams) (*CompleteResult, error) +``` + Servers automatically handle these requests based on their collections of prompts and resources. +**Differences from mcp-go**: the client API is similar. mcp-go has not yet +defined its server-side behavior. + ### Logging Servers have access to a `slog.Logger` that writes to the client. A call to @@ -778,138 +929,3 @@ follows: ### Pagination - -## Differences with mark3labs/mcp-go - -The most popular MCP module for Go is [mark3labs/mcp-go](https://pkg.go.dev/github.com/ -mark3labs/mcp-go). -As of this writing, it is imported by over 400 packages that span over 200 modules. - -We admire mcp-go, and seriously considered simply adopting it as a starting -point for this SDK. However, as we looked at doing so, we realized that a -significant amount of its API would probably need to change. In some cases, -mcp-go has older APIs that predated newer variations--an obvious opportunity -for cleanup. In others, it took a batteries-included approach that is probably -not viable for an official SDK. In yet others, we simply think there is room for -API refinement, and we should take this opportunity to reconsider. Therefore, -we wrote this SDK design from the perspective of a new implementation. -Nevertheless, much of the API discussed here originated from or was inspired -by mcp-go and other unofficial SDKs, and if the consensus of this discussion is -close enough to mcp-go or any other unofficial SDK, we can start from a fork. 
- -Although our API is not compatible with mcp-go's, translating between them should be -straightforward in most cases. -(Later, we will provide a detailed translation guide.) - -### Packages - -As we mentioned above, we decided to put most of the API into a single package. -Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, -`server` and `transport` packages, but is smaller than the `mcp` package alone. - -### Typed tool inputs - -We provide a way to supply a struct as the input type of a Tool, as described -in [JSON Schema](#JSON_Schema), above. -The tool handler receives a value of this struct instead of a `map[string]any`, -so it doesn't need to parse its input parameters. Also, we infer the input schema -from the struct, avoiding the need to specify the name, type and required status of -parameters. - -### Schema validation - -We provide a full JSON Schema implementation for validating tool input schemas against -incoming arguments. The `jsonschema.Schema` type provides exported features for all -keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct -any schema they want, so there is no need to provide options for all of them. -When combined with schema inference from input structs, -we found that we needed only three options to cover the common cases, -instead of mcp-go's 23. For example, we provide `Enum`, which occurs 125 times in open source -code, but not MinItems, MinLength or MinProperties, which each occur only once (and in an SDK -that wraps mcp-go). - -Moreover, our options can be used to build nested schemas, while -mcp-go's work only at top level. 
That limitation is visible in -[this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), -which must resort to untyped maps to express a nested schema: - -```go -mcp.WithArray("items", - mcp.Description("Checklist items of the task"), - mcp.Items(map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - "description": "Unique identifier of the checklist item", - }, - "status": map[string]any{ - "type": "number", - "description": "Status of the checklist item (0: normal, 1: completed)", - "enum": []float64{0, 1}, - }, - ... -``` - -### JSON-RPC implementation - -The Go team has a battle-tested JSON-RPC implementation that we use for gopls, our -Go LSP server. We are using the new version of this library as part of our MCP SDK. -It handles all JSON-RPC 2.0 features, including cancellation. - -### Hooks - -Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field -in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. -These are rarely used. The most common is `OnError`, which occurs fewer than ten -times in open-source code. - -All of the hooks run before or after the server processes a message, -so instead we provide a single way to intercept this message handling, using -two exported names instead of 72: - -```go -// A Handler handles an MCP message call. -type Handler func(ctx context.Context, s *ServerSession, method string, params any) (response any, err error) - -// AddMiddleware calls each middleware function from right to left on the previous result, beginning -// with the server's current handler, and installs the result as the new handler. 
-func (*Server) AddMiddleware(middleware ...func(Handler) Handler)) -``` - -As an example, this code adds server-side logging: - -```go -func withLogging(h mcp.Handler) mcp.Handler { - return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { - log.Printf("request: %s %v", method, params) - defer func() { log.Printf("response: %v, %v", res, err) }() - return h(ctx, s , method, params) - } -} - -server.AddMiddleware(withLogging) -``` - -### Options - -In Go, the two most common ways to provide options to a function are option structs (for example, -https://pkg.go.dev/net/http#PushOptions) and -variadic option functions. mcp-go uses option functions exclusively. For example, -the `server.NewMCPServer` function has ten associated functions to provide options. -Our API uses both, depending on the context. We use function options for -constructing tools, where they are most convenient. In most other places, we -prefer structs because they have a smaller API footprint and are less verbose. - -### Servers - -In mcp-go, server authors create an `MCPServer`, populate it with tools, -resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. These -also use session IDs, which are exposed. Users can manage their own sessions -with `RegisterSession` and `UnregisterSession`. - -We find the similarity in names among the three server types to be confusing, -and we could not discover any uses of the session methods in the open-source -ecosystem. In our design, server authors create a `Server`, and then -connect it to a `Transport`. An `SSEHandler` manages sessions for -incoming SSE connections, but does not expose them. From ce6fe291a58c56822625ba8e7805940c57484603 Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Fri, 9 May 2025 08:48:39 -0400 Subject: [PATCH 036/196] gopls/internal/completion: apply modernizers This CL replaces a lot of for loops with range loops. 
Change-Id: I70e4742afe3dd223d953f01d4ca832f27d709541 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671395 LUCI-TryBot-Result: Go LUCI Reviewed-by: Madeline Kalil --- gopls/internal/golang/completion/completion.go | 16 ++++++++-------- gopls/internal/golang/completion/format.go | 16 ++++++++-------- gopls/internal/golang/completion/fuzz.go | 4 ++-- gopls/internal/golang/completion/literal.go | 12 ++++++------ .../golang/completion/postfix_snippets.go | 4 ++-- gopls/internal/golang/completion/statements.go | 6 +++--- gopls/internal/golang/completion/util.go | 6 +++--- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 47fcbf463eb..ddaaac15ece 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -1117,7 +1117,7 @@ func (c *completer) populateCommentCompletions(comment *ast.CommentGroup) { _, named := typesinternal.ReceiverNamed(recv) if named != nil { if recvStruct, ok := named.Underlying().(*types.Struct); ok { - for i := 0; i < recvStruct.NumFields(); i++ { + for i := range recvStruct.NumFields() { field := recvStruct.Field(i) c.deepState.enqueue(candidate{obj: field, score: lowScore}) } @@ -1614,7 +1614,7 @@ func (c *completer) methodsAndFields(typ types.Type, addressable bool, imp *impo c.methodSetCache[methodSetKey{typ, addressable}] = mset } - for i := 0; i < mset.Len(); i++ { + for i := range mset.Len() { obj := mset.At(i).Obj() // to the other side of the cb() queue? if c.tooNew(obj) { @@ -2001,7 +2001,7 @@ func (c *completer) structLiteralFieldName(ctx context.Context) error { // Add struct fields. 
if t, ok := types.Unalias(clInfo.clType).(*types.Struct); ok { const deltaScore = 0.0001 - for i := 0; i < t.NumFields(); i++ { + for i := range t.NumFields() { field := t.Field(i) if !addedFields[field] { c.deepState.enqueue(candidate{ @@ -2170,7 +2170,7 @@ func expectedCompositeLiteralType(clInfo *compLitInfo, pos token.Pos) types.Type // value side. The expected type of the value will be determined from the key. if clInfo.kv != nil { if key, ok := clInfo.kv.Key.(*ast.Ident); ok { - for i := 0; i < t.NumFields(); i++ { + for i := range t.NumFields() { if field := t.Field(i); field.Name() == key.Name { return field.Type() } @@ -2755,7 +2755,7 @@ func reverseInferTypeArgs(sig *types.Signature, typeArgs []types.Type, expectedR } substs := make([]types.Type, sig.TypeParams().Len()) - for i := 0; i < sig.TypeParams().Len(); i++ { + for i := range sig.TypeParams().Len() { if sub := u.handles[sig.TypeParams().At(i)]; sub != nil && *sub != nil { // Ensure the inferred subst is assignable to the type parameter's constraint. if !assignableTo(*sub, sig.TypeParams().At(i).Constraint()) { @@ -2835,7 +2835,7 @@ func (c *completer) expectedCallParamType(inf candidateInference, node *ast.Call // call. Record the assignees so we can favor function // calls that return matching values. if len(node.Args) <= 1 && exprIdx == 0 { - for i := 0; i < sig.Params().Len(); i++ { + for i := range sig.Params().Len() { inf.assignees = append(inf.assignees, sig.Params().At(i).Type()) } @@ -2916,7 +2916,7 @@ func objChain(info *types.Info, e ast.Expr) []types.Object { } // Reverse order so the layout matches the syntactic order. - for i := 0; i < len(objs)/2; i++ { + for i := range len(objs) / 2 { objs[i], objs[len(objs)-1-i] = objs[len(objs)-1-i], objs[i] } @@ -3502,7 +3502,7 @@ func (ci *candidateInference) assigneesMatch(cand *candidate, sig *types.Signatu // assignees match the corresponding sig result value, the signature // is a match. 
allMatch := false - for i := 0; i < sig.Results().Len(); i++ { + for i := range sig.Results().Len() { var assignee types.Type // If we are completing into variadic parameters, deslice the diff --git a/gopls/internal/golang/completion/format.go b/gopls/internal/golang/completion/format.go index 69339bffe84..5c9d81cff39 100644 --- a/gopls/internal/golang/completion/format.go +++ b/gopls/internal/golang/completion/format.go @@ -398,32 +398,32 @@ func inferableTypeParams(sig *types.Signature) map[*types.TypeParam]bool { case *types.Slice: visit(t.Elem()) case *types.Interface: - for i := 0; i < t.NumExplicitMethods(); i++ { + for i := range t.NumExplicitMethods() { visit(t.ExplicitMethod(i).Type()) } - for i := 0; i < t.NumEmbeddeds(); i++ { + for i := range t.NumEmbeddeds() { visit(t.EmbeddedType(i)) } case *types.Union: - for i := 0; i < t.Len(); i++ { + for i := range t.Len() { visit(t.Term(i).Type()) } case *types.Signature: if tp := t.TypeParams(); tp != nil { // Generic signatures only appear as the type of generic // function declarations, so this isn't really reachable. - for i := 0; i < tp.Len(); i++ { + for i := range tp.Len() { visit(tp.At(i).Constraint()) } } visit(t.Params()) visit(t.Results()) case *types.Tuple: - for i := 0; i < t.Len(); i++ { + for i := range t.Len() { visit(t.At(i).Type()) } case *types.Struct: - for i := 0; i < t.NumFields(); i++ { + for i := range t.NumFields() { visit(t.Field(i).Type()) } case *types.TypeParam: @@ -432,7 +432,7 @@ func inferableTypeParams(sig *types.Signature) map[*types.TypeParam]bool { visit(types.Unalias(t)) case *types.Named: targs := t.TypeArgs() - for i := 0; i < targs.Len(); i++ { + for i := range targs.Len() { visit(targs.At(i)) } case *types.Basic: @@ -446,7 +446,7 @@ func inferableTypeParams(sig *types.Signature) map[*types.TypeParam]bool { // Perform induction through constraints. 
restart: - for i := 0; i < sig.TypeParams().Len(); i++ { + for i := range sig.TypeParams().Len() { tp := sig.TypeParams().At(i) if free[tp] { n := len(free) diff --git a/gopls/internal/golang/completion/fuzz.go b/gopls/internal/golang/completion/fuzz.go index 3f5ac99c428..9e3bb7ba1e2 100644 --- a/gopls/internal/golang/completion/fuzz.go +++ b/gopls/internal/golang/completion/fuzz.go @@ -50,7 +50,7 @@ Loop: } } if inside { - for i := 0; i < mset.Len(); i++ { + for i := range mset.Len() { o := mset.At(i).Obj() if o.Name() == "Failed" || o.Name() == "Name" { cb(candidate{ @@ -125,7 +125,7 @@ Loop: isSlice: false, } c.items = append(c.items, xx) - for i := 0; i < mset.Len(); i++ { + for i := range mset.Len() { o := mset.At(i).Obj() if o.Name() != "Fuzz" { cb(candidate{ diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index ef077ab7e20..20cce04b69f 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -214,7 +214,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m paramNameCount = make(map[string]int) hasTypeParams bool ) - for i := 0; i < sig.Params().Len(); i++ { + for i := range sig.Params().Len() { var ( p = sig.Params().At(i) name = p.Name() @@ -258,7 +258,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m } } - for i := 0; i < sig.Params().Len(); i++ { + for i := range sig.Params().Len() { if hasTypeParams && !c.opts.placeholders { // If there are type params in the args then the user must // choose the concrete types. 
If placeholders are disabled just @@ -331,7 +331,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m results.Len() == 1 && results.At(0).Name() != "" var resultHasTypeParams bool - for i := 0; i < results.Len(); i++ { + for i := range results.Len() { if tp, ok := types.Unalias(results.At(i).Type()).(*types.TypeParam); ok && !c.typeParamInScope(tp) { resultHasTypeParams = true } @@ -340,7 +340,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if resultsNeedParens { snip.WriteText("(") } - for i := 0; i < results.Len(); i++ { + for i := range results.Len() { if resultHasTypeParams && !c.opts.placeholders { // Leave an empty tabstop if placeholders are disabled and there // are type args that need specificying. @@ -535,7 +535,7 @@ func (c *completer) typeNameSnippet(literalType types.Type, qual types.Qualifier snip.WriteText(typeName + "[") if c.opts.placeholders { - for i := 0; i < tparams.Len(); i++ { + for i := range tparams.Len() { if i > 0 { snip.WriteText(", ") } @@ -567,7 +567,7 @@ func (c *completer) fullyInstantiated(t typesinternal.NamedOrAlias) bool { return false } - for i := 0; i < targs.Len(); i++ { + for i := range targs.Len() { targ := targs.At(i) // The expansion of an alias can have free type parameters, diff --git a/gopls/internal/golang/completion/postfix_snippets.go b/gopls/internal/golang/completion/postfix_snippets.go index 1d306e3518d..e81fb67a2ed 100644 --- a/gopls/internal/golang/completion/postfix_snippets.go +++ b/gopls/internal/golang/completion/postfix_snippets.go @@ -436,7 +436,7 @@ func (a *postfixTmplArgs) Tuple() []*types.Var { } typs := make([]*types.Var, 0, tuple.Len()) - for i := 0; i < tuple.Len(); i++ { + for i := range tuple.Len() { typs = append(typs, tuple.At(i)) } return typs @@ -564,7 +564,7 @@ func (c *completer) addPostfixSnippetCandidates(ctx context.Context, sel *ast.Se results := c.enclosingFunc.sig.Results() if results != nil { funcResults = 
make([]*types.Var, results.Len()) - for i := 0; i < results.Len(); i++ { + for i := range results.Len() { funcResults[i] = results.At(i) } } diff --git a/gopls/internal/golang/completion/statements.go b/gopls/internal/golang/completion/statements.go index 3791211d6a6..e8b35a4cfdb 100644 --- a/gopls/internal/golang/completion/statements.go +++ b/gopls/internal/golang/completion/statements.go @@ -294,7 +294,7 @@ func (c *completer) addErrCheck() { label = fmt.Sprintf("%[1]s != nil { %[2]s.Fatal(%[1]s) }", errVar, testVar) } else { snip.WriteText("return ") - for i := 0; i < result.Len()-1; i++ { + for i := range result.Len() - 1 { if zero, isValid := typesinternal.ZeroString(result.At(i).Type(), c.qual); isValid { snip.WriteText(zero) } @@ -351,7 +351,7 @@ func getTestVar(enclosingFunc *funcInfo, pkg *cache.Package) string { } sig := enclosingFunc.sig - for i := 0; i < sig.Params().Len(); i++ { + for i := range sig.Params().Len() { param := sig.Params().At(i) if param.Name() == "_" { continue @@ -401,7 +401,7 @@ func (c *completer) addReturnZeroValues() { snip.WriteText("return ") fmt.Fprintf(&label, "return ") - for i := 0; i < result.Len(); i++ { + for i := range result.Len() { if i > 0 { snip.WriteText(", ") fmt.Fprintf(&label, ", ") diff --git a/gopls/internal/golang/completion/util.go b/gopls/internal/golang/completion/util.go index 306078296c1..fe1b86fdea2 100644 --- a/gopls/internal/golang/completion/util.go +++ b/gopls/internal/golang/completion/util.go @@ -48,7 +48,7 @@ func eachField(T types.Type, fn func(*types.Var)) { return } - for i := 0; i < T.NumFields(); i++ { + for i := range T.NumFields() { f := T.Field(i) fn(f) if f.Anonymous() { @@ -85,7 +85,7 @@ func typeIsValid(typ types.Type) bool { case *types.Signature: return typeIsValid(typ.Params()) && typeIsValid(typ.Results()) case *types.Tuple: - for i := 0; i < typ.Len(); i++ { + for i := range typ.Len() { if !typeIsValid(typ.At(i).Type()) { return false } @@ -242,7 +242,7 @@ func typeConversion(call 
*ast.CallExpr, info *types.Info) types.Type { // fieldsAccessible returns whether s has at least one field accessible by p. func fieldsAccessible(s *types.Struct, p *types.Package) bool { - for i := 0; i < s.NumFields(); i++ { + for i := range s.NumFields() { f := s.Field(i) if f.Exported() || f.Pkg() == p { return true From a240192bdfd79024f2cea5793a5758c1222cee41 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 13:47:46 -0400 Subject: [PATCH 037/196] internal/mcp/design.md: minor changes Fix typos and formatting, reword slightly. Also, change names in the middleware section. Change-Id: Ie86e5f633a07675ab188a8aaf70e51a48b0878bb Reviewed-on: https://go-review.googlesource.com/c/tools/+/671555 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 72 ++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index de3a07bde75..a7a22d644d6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -59,7 +59,7 @@ SDK. An official SDK should aim to be: In the sections below, we visit each aspect of the MCP spec, in approximately the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26) For each, we discuss considerations for the Go implementation. In many cases an -API is suggested, though in some there many be open questions. +API is suggested, though in some there may be open questions. ## Foundations @@ -123,10 +123,11 @@ type Stream interface { Close() error } ``` +Methods accept a Go `Context` and return an `error`, +as is idiomatic for APIs that do I/O. -Specifically, a `Transport` is something that connects a logical JSON-RPC -stream, and nothing more (methods accept a Go `Context` and return an `error`, -as is idiomatic for APIs that do I/O). 
Streams must be closeable in order to +A `Transport` is something that connects a logical JSON-RPC +stream, and nothing more. Streams must be closeable in order to implement client and server shutdown, and therefore conform to the `io.Closer` interface. @@ -202,9 +203,10 @@ func (*SSEHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) func (*SSEHTTPHandler) Close() error ``` -Notably absent are options to hook into the request handling for the purposes +Notably absent are options to hook into low-level request handling for the purposes of authentication or context injection. These concerns are better handled using -standard HTTP middleware patterns. +standard HTTP middleware patterns. For middleware at the level of the MCP protocol, +see [Middleware](#Middleware) below. By default, the SSE handler creates messages endpoints with the `?sessionId=...` query parameter. Users that want more control over the @@ -247,9 +249,7 @@ type SSEClientTransport struct { /* ... */ } // NewSSEClientTransport returns a new client transport that connects to the // SSE server at the provided URL. -// -// NewSSEClientTransport panics if the given URL is invalid. -func NewSSEClientTransport(url string) *SSEClientTransport { +func NewSSEClientTransport(url string) (*SSEClientTransport, error) { // Connect connects through the client endpoint. func (*SSEClientTransport) Connect(ctx context.Context) (Stream, error) @@ -378,12 +378,13 @@ and resources from servers. Additionally, handlers for these features may themselves be stateful, for example if a tool handler caches state from earlier requests in the session. -We believe that in the common case, both clients and servers are stateless, and -it is therefore more useful to allow multiple connections from a client, and to +We believe that in the common case, any change to a client or server, +such as adding a tool, is intended for all its peers. 
+It is therefore more useful to allow multiple connections from a client, and to a server. This is similar to the `net/http` packages, in which an `http.Client` and `http.Server` each may handle multiple unrelated connections. When users add features to a client or server, all connected peers are notified of the -change in feature-set. +change. Following the terminology of the [spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), @@ -435,9 +436,7 @@ client := mcp.NewClient("mcp-client", "v1.0.0", nil) // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) -if err != nil { - log.Fatal(err) -} +if err != nil { ... } // Call a tool on the server. content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}) ... @@ -491,12 +490,14 @@ We provide a mechanism to add MCP-level middleware, which runs after the request has been parsed, but before any normal handling. ```go -// A Handler handles an MCP message call. -type Handler func(ctx context.Context, s *ServerSession, method string, params any) (response any, err error) +// A Dispatcher dispatches an MCP message to the appropriate handler. +// The params argument will be an XXXParams struct pointer, such as *GetPromptParams. +// The response if err is non-nil should be an XXXResult struct pointer. +type Dispatcher func(ctx context.Context, s *ServerSession, method string, params any) (result any, err error) -// AddMiddleware calls each middleware function from right to left on the previous result, beginning -// with the server's current handler, and installs the result as the new handler. -func (*Server) AddMiddleware(middleware ...func(Handler) Handler)) +// AddDispatchers calls each function from right to left on the previous result, beginning +// with the server's current dispatcher, and installs the result as the new handler. 
+func (*Server) AddDispatchers(middleware ...func(Handler) Handler)) ``` As an example, this code adds server-side logging: @@ -510,7 +511,7 @@ func withLogging(h mcp.Handler) mcp.Handler { } } -server.AddMiddleware(withLogging) +server.AddDispatchers(withLogging) ``` **Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. @@ -558,14 +559,14 @@ When this client call is cancelled, a `"notifications/cancelled"` notification is sent to the server. However, the client call returns immediately with `ctx.Err()`: it does not wait for the result from the server. -The server observes a client cancellation as cancelled context. +The server observes a client cancellation as a cancelled context. ### Progress handling A caller can request progress notifications by setting the `ProgressToken` field on any request. ```go -type ProgressToken any +type ProgressToken any // string or int type XXXParams struct { // where XXX is each type of call ... @@ -595,7 +596,8 @@ func (c *ServerSession) Ping(ctx context.Context) error ``` Additionally, client and server sessions can be configured with automatic -keepalive behavior. If set to a non-zero value, this duration defines an +keepalive behavior. If the `KeepAlive` option is set to a non-zero duration, +it defines an interval for regular "ping" requests. If the peer fails to respond to pings originating from the keepalive check, the session is automatically closed. @@ -707,9 +709,9 @@ tools is the need to associate them with a Go function, yet support the arbitrary complexity of JSON Schema. To achieve this, we have seen two primary approaches: -1. Use reflection to generate the tool's input schema from a Go type (ala +1. Use reflection to generate the tool's input schema from a Go type (à la `metoro-io/mcp-golang`) -2. Explicitly build the input schema (ala `mark3labs/mcp-go`). +2. Explicitly build the input schema (à la `mark3labs/mcp-go`). Both of these have their advantages and disadvantages. 
Reflection is nice, because it allows you to bind directly to a Go API, and means that the JSON @@ -718,10 +720,10 @@ means that concerns like parsing and validation can be handled automatically. However, it can become cumbersome to express the full breadth of JSON schema using Go types or struct tags, and sometimes you want to express things that aren’t naturally modeled by Go types, like unions. Explicit schemas are simple -and readable, and gives the caller full control over their tool definition, but +and readable, and give the caller full control over their tool definition, but involve significant boilerplate. -We believe that a hybrid model works well, where the _initial_ schema is +We have found that a hybrid model works well, where the _initial_ schema is derived using reflection, but any customization on top of that schema is applied using variadic options. We achieve this using a `NewTool` helper, which generates the schema from the input type, and wraps the handler to provide @@ -729,7 +731,7 @@ parsing and validation. The schema (and potentially other features) can be customized using ToolOptions. ```go -// NewTool is a creates a Tool using reflection on the given handler. +// NewTool creates a Tool using reflection on the given handler. func NewTool[TInput any](name, description string, handler func(context.Context, *ServerSession, TInput) ([]Content, error), opts …ToolOption) *ServerTool type ToolOption interface { /* ... */ } @@ -752,7 +754,7 @@ struct { "name" and "Choices" are required, while "count" is optional. -As of writing, the only `ToolOption` is `Input`, which allows customizing the +As of this writing, the only `ToolOption` is `Input`, which allows customizing the input schema of the tool using schema options. These schema options are recursive, in the sense that they may also be applied to properties. 
@@ -840,7 +842,7 @@ server.AddPrompts( server.RemovePrompts("code_review") ``` -Clients can call ListPrompts to list the available prompts and GetPrompt to get one. +Clients can call `ListPrompts` to list the available prompts and `GetPrompt` to get one. ```go func (*ClientSession) ListPrompts(context.Context, *ListPromptParams) (*ListPromptsResult, error) @@ -862,8 +864,8 @@ func (*Server) AddResourceTemplates(templates...*ResourceTemplate) func (*Server) RemoveResourceTemplates(names ...string) ``` -Clients call ListResources to list the available resources, ReadResource to read -one of them, and ListResourceTemplates to list the templates: +Clients call `ListResources` to list the available resources, `ReadResource` to read +one of them, and `ListResourceTemplates` to list the templates: ```go func (*ClientSession) ListResources(context.Context, *ListResourcesParams) (*ListResourcesResult, error) @@ -909,7 +911,7 @@ defined its server-side behavior. ### Logging -Servers have access to a `slog.Logger` that writes to the client. A call to +ServerSessions have access to a `slog.Logger` that writes to the client. A call to a log method like `Info`is translated to a `LoggingMessageNotification` as follows: @@ -924,7 +926,7 @@ follows: to integers between the slog levels. For example, "notice" is level 2 because it is between "warning" (slog value 4) and "info" (slog value 0). The `mcp` package defines consts for these levels. To log at the "notice" - level, a server would call `Log(ctx, mcp.LevelNotice, "message")`. + level, a handler would call `session.Log(ctx, mcp.LevelNotice, "message")`. ### Pagination From 7b18363df201ff1ce84e5dfdcb2475403d2b6d2b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 7 May 2025 23:26:42 -0400 Subject: [PATCH 038/196] go/ast/inspector: publish Cursor This CL moves internal/astutil/cursor to this package (and internal/astutil/edge to go/ast/edge) and makes both public. 
The API is identical except that At and Root are now methods on Inspector. The old cursor and edge packages are replaced by aliases, and will be removed in a follow-up, using //go:fix inline. Fixes golang/go#70859 Change-Id: Ibada236e9deb467bce40086dfb73d9199e60d0a0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670835 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan Reviewed-by: Jonathan Amsterdam --- go/ast/edge/edge.go | 295 +++++++++++ go/ast/inspector/cursor.go | 493 +++++++++++++++++ .../ast/inspector}/cursor_test.go | 99 +--- go/ast/inspector/inspector.go | 28 +- go/ast/inspector/inspector_test.go | 27 +- go/ast/inspector/iter_test.go | 2 - internal/astutil/cursor/cursor.go | 500 +----------------- internal/astutil/cursor/hooks.go | 40 -- internal/astutil/edge/edge.go | 386 ++++---------- 9 files changed, 952 insertions(+), 918 deletions(-) create mode 100644 go/ast/edge/edge.go create mode 100644 go/ast/inspector/cursor.go rename {internal/astutil/cursor => go/ast/inspector}/cursor_test.go (84%) delete mode 100644 internal/astutil/cursor/hooks.go diff --git a/go/ast/edge/edge.go b/go/ast/edge/edge.go new file mode 100644 index 00000000000..4f6ccfd6e5e --- /dev/null +++ b/go/ast/edge/edge.go @@ -0,0 +1,295 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package edge defines identifiers for each field of an ast.Node +// struct type that refers to another Node. +package edge + +import ( + "fmt" + "go/ast" + "reflect" +) + +// A Kind describes a field of an ast.Node struct. +type Kind uint8 + +// String returns a description of the edge kind. +func (k Kind) String() string { + if k == Invalid { + return "" + } + info := fieldInfos[k] + return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.name) +} + +// NodeType returns the pointer-to-struct type of the ast.Node implementation. 
+func (k Kind) NodeType() reflect.Type { return fieldInfos[k].nodeType } + +// FieldName returns the name of the field. +func (k Kind) FieldName() string { return fieldInfos[k].name } + +// FieldType returns the declared type of the field. +func (k Kind) FieldType() reflect.Type { return fieldInfos[k].fieldType } + +// Get returns the direct child of n identified by (k, idx). +// n's type must match k.NodeType(). +// idx must be a valid slice index, or -1 for a non-slice. +func (k Kind) Get(n ast.Node, idx int) ast.Node { + if k.NodeType() != reflect.TypeOf(n) { + panic(fmt.Sprintf("%v.Get(%T): invalid node type", k, n)) + } + v := reflect.ValueOf(n).Elem().Field(fieldInfos[k].index) + if idx != -1 { + v = v.Index(idx) // asserts valid index + } else { + // (The type assertion below asserts that v is not a slice.) + } + return v.Interface().(ast.Node) // may be nil +} + +const ( + Invalid Kind = iota // for nodes at the root of the traversal + + // Kinds are sorted alphabetically. + // Numbering is not stable. 
+ // Each is named Type_Field, where Type is the + // ast.Node struct type and Field is the name of the field + + ArrayType_Elt + ArrayType_Len + AssignStmt_Lhs + AssignStmt_Rhs + BinaryExpr_X + BinaryExpr_Y + BlockStmt_List + BranchStmt_Label + CallExpr_Args + CallExpr_Fun + CaseClause_Body + CaseClause_List + ChanType_Value + CommClause_Body + CommClause_Comm + CommentGroup_List + CompositeLit_Elts + CompositeLit_Type + DeclStmt_Decl + DeferStmt_Call + Ellipsis_Elt + ExprStmt_X + FieldList_List + Field_Comment + Field_Doc + Field_Names + Field_Tag + Field_Type + File_Decls + File_Doc + File_Name + ForStmt_Body + ForStmt_Cond + ForStmt_Init + ForStmt_Post + FuncDecl_Body + FuncDecl_Doc + FuncDecl_Name + FuncDecl_Recv + FuncDecl_Type + FuncLit_Body + FuncLit_Type + FuncType_Params + FuncType_Results + FuncType_TypeParams + GenDecl_Doc + GenDecl_Specs + GoStmt_Call + IfStmt_Body + IfStmt_Cond + IfStmt_Else + IfStmt_Init + ImportSpec_Comment + ImportSpec_Doc + ImportSpec_Name + ImportSpec_Path + IncDecStmt_X + IndexExpr_Index + IndexExpr_X + IndexListExpr_Indices + IndexListExpr_X + InterfaceType_Methods + KeyValueExpr_Key + KeyValueExpr_Value + LabeledStmt_Label + LabeledStmt_Stmt + MapType_Key + MapType_Value + ParenExpr_X + RangeStmt_Body + RangeStmt_Key + RangeStmt_Value + RangeStmt_X + ReturnStmt_Results + SelectStmt_Body + SelectorExpr_Sel + SelectorExpr_X + SendStmt_Chan + SendStmt_Value + SliceExpr_High + SliceExpr_Low + SliceExpr_Max + SliceExpr_X + StarExpr_X + StructType_Fields + SwitchStmt_Body + SwitchStmt_Init + SwitchStmt_Tag + TypeAssertExpr_Type + TypeAssertExpr_X + TypeSpec_Comment + TypeSpec_Doc + TypeSpec_Name + TypeSpec_Type + TypeSpec_TypeParams + TypeSwitchStmt_Assign + TypeSwitchStmt_Body + TypeSwitchStmt_Init + UnaryExpr_X + ValueSpec_Comment + ValueSpec_Doc + ValueSpec_Names + ValueSpec_Type + ValueSpec_Values + + maxKind +) + +// Assert that the encoding fits in 7 bits, +// as the inspector relies on this. +// (We are currently at 104.) 
+var _ = [1 << 7]struct{}{}[maxKind] + +type fieldInfo struct { + nodeType reflect.Type // pointer-to-struct type of ast.Node implementation + name string + index int + fieldType reflect.Type +} + +func info[N ast.Node](fieldName string) fieldInfo { + nodePtrType := reflect.TypeFor[N]() + f, ok := nodePtrType.Elem().FieldByName(fieldName) + if !ok { + panic(fieldName) + } + return fieldInfo{nodePtrType, fieldName, f.Index[0], f.Type} +} + +var fieldInfos = [...]fieldInfo{ + Invalid: {}, + ArrayType_Elt: info[*ast.ArrayType]("Elt"), + ArrayType_Len: info[*ast.ArrayType]("Len"), + AssignStmt_Lhs: info[*ast.AssignStmt]("Lhs"), + AssignStmt_Rhs: info[*ast.AssignStmt]("Rhs"), + BinaryExpr_X: info[*ast.BinaryExpr]("X"), + BinaryExpr_Y: info[*ast.BinaryExpr]("Y"), + BlockStmt_List: info[*ast.BlockStmt]("List"), + BranchStmt_Label: info[*ast.BranchStmt]("Label"), + CallExpr_Args: info[*ast.CallExpr]("Args"), + CallExpr_Fun: info[*ast.CallExpr]("Fun"), + CaseClause_Body: info[*ast.CaseClause]("Body"), + CaseClause_List: info[*ast.CaseClause]("List"), + ChanType_Value: info[*ast.ChanType]("Value"), + CommClause_Body: info[*ast.CommClause]("Body"), + CommClause_Comm: info[*ast.CommClause]("Comm"), + CommentGroup_List: info[*ast.CommentGroup]("List"), + CompositeLit_Elts: info[*ast.CompositeLit]("Elts"), + CompositeLit_Type: info[*ast.CompositeLit]("Type"), + DeclStmt_Decl: info[*ast.DeclStmt]("Decl"), + DeferStmt_Call: info[*ast.DeferStmt]("Call"), + Ellipsis_Elt: info[*ast.Ellipsis]("Elt"), + ExprStmt_X: info[*ast.ExprStmt]("X"), + FieldList_List: info[*ast.FieldList]("List"), + Field_Comment: info[*ast.Field]("Comment"), + Field_Doc: info[*ast.Field]("Doc"), + Field_Names: info[*ast.Field]("Names"), + Field_Tag: info[*ast.Field]("Tag"), + Field_Type: info[*ast.Field]("Type"), + File_Decls: info[*ast.File]("Decls"), + File_Doc: info[*ast.File]("Doc"), + File_Name: info[*ast.File]("Name"), + ForStmt_Body: info[*ast.ForStmt]("Body"), + ForStmt_Cond: info[*ast.ForStmt]("Cond"), 
+ ForStmt_Init: info[*ast.ForStmt]("Init"), + ForStmt_Post: info[*ast.ForStmt]("Post"), + FuncDecl_Body: info[*ast.FuncDecl]("Body"), + FuncDecl_Doc: info[*ast.FuncDecl]("Doc"), + FuncDecl_Name: info[*ast.FuncDecl]("Name"), + FuncDecl_Recv: info[*ast.FuncDecl]("Recv"), + FuncDecl_Type: info[*ast.FuncDecl]("Type"), + FuncLit_Body: info[*ast.FuncLit]("Body"), + FuncLit_Type: info[*ast.FuncLit]("Type"), + FuncType_Params: info[*ast.FuncType]("Params"), + FuncType_Results: info[*ast.FuncType]("Results"), + FuncType_TypeParams: info[*ast.FuncType]("TypeParams"), + GenDecl_Doc: info[*ast.GenDecl]("Doc"), + GenDecl_Specs: info[*ast.GenDecl]("Specs"), + GoStmt_Call: info[*ast.GoStmt]("Call"), + IfStmt_Body: info[*ast.IfStmt]("Body"), + IfStmt_Cond: info[*ast.IfStmt]("Cond"), + IfStmt_Else: info[*ast.IfStmt]("Else"), + IfStmt_Init: info[*ast.IfStmt]("Init"), + ImportSpec_Comment: info[*ast.ImportSpec]("Comment"), + ImportSpec_Doc: info[*ast.ImportSpec]("Doc"), + ImportSpec_Name: info[*ast.ImportSpec]("Name"), + ImportSpec_Path: info[*ast.ImportSpec]("Path"), + IncDecStmt_X: info[*ast.IncDecStmt]("X"), + IndexExpr_Index: info[*ast.IndexExpr]("Index"), + IndexExpr_X: info[*ast.IndexExpr]("X"), + IndexListExpr_Indices: info[*ast.IndexListExpr]("Indices"), + IndexListExpr_X: info[*ast.IndexListExpr]("X"), + InterfaceType_Methods: info[*ast.InterfaceType]("Methods"), + KeyValueExpr_Key: info[*ast.KeyValueExpr]("Key"), + KeyValueExpr_Value: info[*ast.KeyValueExpr]("Value"), + LabeledStmt_Label: info[*ast.LabeledStmt]("Label"), + LabeledStmt_Stmt: info[*ast.LabeledStmt]("Stmt"), + MapType_Key: info[*ast.MapType]("Key"), + MapType_Value: info[*ast.MapType]("Value"), + ParenExpr_X: info[*ast.ParenExpr]("X"), + RangeStmt_Body: info[*ast.RangeStmt]("Body"), + RangeStmt_Key: info[*ast.RangeStmt]("Key"), + RangeStmt_Value: info[*ast.RangeStmt]("Value"), + RangeStmt_X: info[*ast.RangeStmt]("X"), + ReturnStmt_Results: info[*ast.ReturnStmt]("Results"), + SelectStmt_Body: 
info[*ast.SelectStmt]("Body"), + SelectorExpr_Sel: info[*ast.SelectorExpr]("Sel"), + SelectorExpr_X: info[*ast.SelectorExpr]("X"), + SendStmt_Chan: info[*ast.SendStmt]("Chan"), + SendStmt_Value: info[*ast.SendStmt]("Value"), + SliceExpr_High: info[*ast.SliceExpr]("High"), + SliceExpr_Low: info[*ast.SliceExpr]("Low"), + SliceExpr_Max: info[*ast.SliceExpr]("Max"), + SliceExpr_X: info[*ast.SliceExpr]("X"), + StarExpr_X: info[*ast.StarExpr]("X"), + StructType_Fields: info[*ast.StructType]("Fields"), + SwitchStmt_Body: info[*ast.SwitchStmt]("Body"), + SwitchStmt_Init: info[*ast.SwitchStmt]("Init"), + SwitchStmt_Tag: info[*ast.SwitchStmt]("Tag"), + TypeAssertExpr_Type: info[*ast.TypeAssertExpr]("Type"), + TypeAssertExpr_X: info[*ast.TypeAssertExpr]("X"), + TypeSpec_Comment: info[*ast.TypeSpec]("Comment"), + TypeSpec_Doc: info[*ast.TypeSpec]("Doc"), + TypeSpec_Name: info[*ast.TypeSpec]("Name"), + TypeSpec_Type: info[*ast.TypeSpec]("Type"), + TypeSpec_TypeParams: info[*ast.TypeSpec]("TypeParams"), + TypeSwitchStmt_Assign: info[*ast.TypeSwitchStmt]("Assign"), + TypeSwitchStmt_Body: info[*ast.TypeSwitchStmt]("Body"), + TypeSwitchStmt_Init: info[*ast.TypeSwitchStmt]("Init"), + UnaryExpr_X: info[*ast.UnaryExpr]("X"), + ValueSpec_Comment: info[*ast.ValueSpec]("Comment"), + ValueSpec_Doc: info[*ast.ValueSpec]("Doc"), + ValueSpec_Names: info[*ast.ValueSpec]("Names"), + ValueSpec_Type: info[*ast.ValueSpec]("Type"), + ValueSpec_Values: info[*ast.ValueSpec]("Values"), +} diff --git a/go/ast/inspector/cursor.go b/go/ast/inspector/cursor.go new file mode 100644 index 00000000000..cd10afa5889 --- /dev/null +++ b/go/ast/inspector/cursor.go @@ -0,0 +1,493 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package inspector + +// TODO(adonovan): +// - review package documentation +// - apply-all //go:fix inline + +import ( + "fmt" + "go/ast" + "go/token" + "iter" + "reflect" + + "golang.org/x/tools/go/ast/edge" +) + +// A Cursor represents an [ast.Node]. It is immutable. +// +// Two Cursors compare equal if they represent the same node. +// +// Call [Inspector.Root] to obtain a valid cursor. +type Cursor struct { + in *Inspector + index int32 // index of push node; -1 for virtual root node +} + +// Root returns a cursor for the virtual root node, +// whose children are the files provided to [New]. +// +// Its [Cursor.Node] and [Cursor.Stack] methods return nil. +func (in *Inspector) Root() Cursor { + return Cursor{in, -1} +} + +// At returns the cursor at the specified index in the traversal, +// which must have been obtained from [Cursor.Index] on a Cursor +// belonging to the same Inspector (see [Cursor.Inspector]). +func (in *Inspector) At(index int32) Cursor { + if index < 0 { + panic("negative index") + } + if int(index) >= len(in.events) { + panic("index out of range for this inspector") + } + if in.events[index].index < index { + panic("invalid index") // (a push, not a pop) + } + return Cursor{in, index} +} + +// Inspector returns the cursor's Inspector. +func (c Cursor) Inspector() *Inspector { return c.in } + +// Index returns the index of this cursor position within the package. +// +// Clients should not assume anything about the numeric Index value +// except that it increases monotonically throughout the traversal. +// It is provided for use with [At]. +// +// Index must not be called on the Root node. +func (c Cursor) Index() int32 { + if c.index < 0 { + panic("Index called on Root node") + } + return c.index +} + +// Node returns the node at the current cursor position, +// or nil for the cursor returned by [Inspector.Root]. 
+func (c Cursor) Node() ast.Node { + if c.index < 0 { + return nil + } + return c.in.events[c.index].node +} + +// String returns information about the cursor's node, if any. +func (c Cursor) String() string { + if c.in == nil { + return "(invalid)" + } + if c.index < 0 { + return "(root)" + } + return reflect.TypeOf(c.Node()).String() +} + +// indices return the [start, end) half-open interval of event indices. +func (c Cursor) indices() (int32, int32) { + if c.index < 0 { + return 0, int32(len(c.in.events)) // root: all events + } else { + return c.index, c.in.events[c.index].index + 1 // just one subtree + } +} + +// Preorder returns an iterator over the nodes of the subtree +// represented by c in depth-first order. Each node in the sequence is +// represented by a Cursor that allows access to the Node, but may +// also be used to start a new traversal, or to obtain the stack of +// nodes enclosing the cursor. +// +// The traversal sequence is determined by [ast.Inspect]. The types +// argument, if non-empty, enables type-based filtering of events. The +// function f if is called only for nodes whose type matches an +// element of the types slice. +// +// If you need control over descent into subtrees, +// or need both pre- and post-order notifications, use [Cursor.Inspect] +func (c Cursor) Preorder(types ...ast.Node) iter.Seq[Cursor] { + mask := maskOf(types) + + return func(yield func(Cursor) bool) { + events := c.in.events + + for i, limit := c.indices(); i < limit; { + ev := events[i] + if ev.index > i { // push? + if ev.typ&mask != 0 && !yield(Cursor{c.in, i}) { + break + } + pop := ev.index + if events[pop].typ&mask == 0 { + // Subtree does not contain types: skip. + i = pop + 1 + continue + } + } + i++ + } + } +} + +// Inspect visits the nodes of the subtree represented by c in +// depth-first order. It calls f(n) for each node n before it +// visits n's children. 
If f returns true, Inspect invokes f +// recursively for each of the non-nil children of the node. +// +// Each node is represented by a Cursor that allows access to the +// Node, but may also be used to start a new traversal, or to obtain +// the stack of nodes enclosing the cursor. +// +// The complete traversal sequence is determined by [ast.Inspect]. +// The types argument, if non-empty, enables type-based filtering of +// events. The function f if is called only for nodes whose type +// matches an element of the types slice. +func (c Cursor) Inspect(types []ast.Node, f func(c Cursor) (descend bool)) { + mask := maskOf(types) + events := c.in.events + for i, limit := c.indices(); i < limit; { + ev := events[i] + if ev.index > i { + // push + pop := ev.index + if ev.typ&mask != 0 && !f(Cursor{c.in, i}) || + events[pop].typ&mask == 0 { + // The user opted not to descend, or the + // subtree does not contain types: + // skip past the pop. + i = pop + 1 + continue + } + } + i++ + } +} + +// Enclosing returns an iterator over the nodes enclosing the current +// current node, starting with the Cursor itself. +// +// Enclosing must not be called on the Root node (whose [Cursor.Node] returns nil). +// +// The types argument, if non-empty, enables type-based filtering of +// events: the sequence includes only enclosing nodes whose type +// matches an element of the types slice. +func (c Cursor) Enclosing(types ...ast.Node) iter.Seq[Cursor] { + if c.index < 0 { + panic("Cursor.Enclosing called on Root node") + } + + mask := maskOf(types) + + return func(yield func(Cursor) bool) { + events := c.in.events + for i := c.index; i >= 0; i = events[i].parent { + if events[i].typ&mask != 0 && !yield(Cursor{c.in, i}) { + break + } + } + } +} + +// Parent returns the parent of the current node. +// +// Parent must not be called on the Root node (whose [Cursor.Node] returns nil). 
+func (c Cursor) Parent() Cursor { + if c.index < 0 { + panic("Cursor.Parent called on Root node") + } + + return Cursor{c.in, c.in.events[c.index].parent} +} + +// ParentEdge returns the identity of the field in the parent node +// that holds this cursor's node, and if it is a list, the index within it. +// +// For example, f(x, y) is a CallExpr whose three children are Idents. +// f has edge kind [edge.CallExpr_Fun] and index -1. +// x and y have kind [edge.CallExpr_Args] and indices 0 and 1, respectively. +// +// If called on a child of the Root node, it returns ([edge.Invalid], -1). +// +// ParentEdge must not be called on the Root node (whose [Cursor.Node] returns nil). +func (c Cursor) ParentEdge() (edge.Kind, int) { + if c.index < 0 { + panic("Cursor.ParentEdge called on Root node") + } + events := c.in.events + pop := events[c.index].index + return unpackEdgeKindAndIndex(events[pop].parent) +} + +// ChildAt returns the cursor for the child of the +// current node identified by its edge and index. +// The index must be -1 if the edge.Kind is not a slice. +// The indicated child node must exist. +// +// ChildAt must not be called on the Root node (whose [Cursor.Node] returns nil). +// +// Invariant: c.Parent().ChildAt(c.ParentEdge()) == c. +func (c Cursor) ChildAt(k edge.Kind, idx int) Cursor { + target := packEdgeKindAndIndex(k, idx) + + // Unfortunately there's no shortcut to looping. + events := c.in.events + i := c.index + 1 + for { + pop := events[i].index + if pop < i { + break + } + if events[pop].parent == target { + return Cursor{c.in, i} + } + i = pop + 1 + } + panic(fmt.Sprintf("ChildAt(%v, %d): no such child of %v", k, idx, c)) +} + +// Child returns the cursor for n, which must be a direct child of c's Node. +// +// Child must not be called on the Root node (whose [Cursor.Node] returns nil). 
+func (c Cursor) Child(n ast.Node) Cursor { + if c.index < 0 { + panic("Cursor.Child called on Root node") + } + + if false { + // reference implementation + for child := range c.Children() { + if child.Node() == n { + return child + } + } + + } else { + // optimized implementation + events := c.in.events + for i := c.index + 1; events[i].index > i; i = events[i].index + 1 { + if events[i].node == n { + return Cursor{c.in, i} + } + } + } + panic(fmt.Sprintf("Child(%T): not a child of %v", n, c)) +} + +// NextSibling returns the cursor for the next sibling node in the same list +// (for example, of files, decls, specs, statements, fields, or expressions) as +// the current node. It returns (zero, false) if the node is the last node in +// the list, or is not part of a list. +// +// NextSibling must not be called on the Root node. +// +// See note at [Cursor.Children]. +func (c Cursor) NextSibling() (Cursor, bool) { + if c.index < 0 { + panic("Cursor.NextSibling called on Root node") + } + + events := c.in.events + i := events[c.index].index + 1 // after corresponding pop + if i < int32(len(events)) { + if events[i].index > i { // push? + return Cursor{c.in, i}, true + } + } + return Cursor{}, false +} + +// PrevSibling returns the cursor for the previous sibling node in the +// same list (for example, of files, decls, specs, statements, fields, +// or expressions) as the current node. It returns zero if the node is +// the first node in the list, or is not part of a list. +// +// It must not be called on the Root node. +// +// See note at [Cursor.Children]. +func (c Cursor) PrevSibling() (Cursor, bool) { + if c.index < 0 { + panic("Cursor.PrevSibling called on Root node") + } + + events := c.in.events + i := c.index - 1 + if i >= 0 { + if j := events[i].index; j < i { // pop? + return Cursor{c.in, j}, true + } + } + return Cursor{}, false +} + +// FirstChild returns the first direct child of the current node, +// or zero if it has no children. 
+func (c Cursor) FirstChild() (Cursor, bool) { + events := c.in.events + i := c.index + 1 // i=0 if c is root + if i < int32(len(events)) && events[i].index > i { // push? + return Cursor{c.in, i}, true + } + return Cursor{}, false +} + +// LastChild returns the last direct child of the current node, +// or zero if it has no children. +func (c Cursor) LastChild() (Cursor, bool) { + events := c.in.events + if c.index < 0 { // root? + if len(events) > 0 { + // return push of final event (a pop) + return Cursor{c.in, events[len(events)-1].index}, true + } + } else { + j := events[c.index].index - 1 // before corresponding pop + // Inv: j == c.index if c has no children + // or j is last child's pop. + if j > c.index { // c has children + return Cursor{c.in, events[j].index}, true + } + } + return Cursor{}, false +} + +// Children returns an iterator over the direct children of the +// current node, if any. +// +// When using Children, NextSibling, and PrevSibling, bear in mind that a +// Node's children may come from different fields, some of which may +// be lists of nodes without a distinguished intervening container +// such as [ast.BlockStmt]. +// +// For example, [ast.CaseClause] has a field List of expressions and a +// field Body of statements, so the children of a CaseClause are a mix +// of expressions and statements. Other nodes that have "uncontained" +// list fields include: +// +// - [ast.ValueSpec] (Names, Values) +// - [ast.CompositeLit] (Type, Elts) +// - [ast.IndexListExpr] (X, Indices) +// - [ast.CallExpr] (Fun, Args) +// - [ast.AssignStmt] (Lhs, Rhs) +// +// So, do not assume that the previous sibling of an ast.Stmt is also +// an ast.Stmt, or if it is, that they are executed sequentially, +// unless you have established that, say, its parent is a BlockStmt +// or its [Cursor.ParentEdge] is [edge.BlockStmt_List]. +// For example, given "for S1; ; S2 {}", the predecessor of S2 is S1, +// even though they are not executed in sequence.
+func (c Cursor) Children() iter.Seq[Cursor] { + return func(yield func(Cursor) bool) { + c, ok := c.FirstChild() + for ok && yield(c) { + c, ok = c.NextSibling() + } + } +} + +// Contains reports whether c contains or is equal to c2. +// +// Both Cursors must belong to the same [Inspector]; +// neither may be its Root node. +func (c Cursor) Contains(c2 Cursor) bool { + if c.in != c2.in { + panic("different inspectors") + } + events := c.in.events + return c.index <= c2.index && events[c2.index].index <= events[c.index].index +} + +// FindNode returns the cursor for node n if it belongs to the subtree +// rooted at c. It returns zero if n is not found. +func (c Cursor) FindNode(n ast.Node) (Cursor, bool) { + + // FindNode is equivalent to this code, + // but more convenient and 15-20% faster: + if false { + for candidate := range c.Preorder(n) { + if candidate.Node() == n { + return candidate, true + } + } + return Cursor{}, false + } + + // TODO(adonovan): opt: should we assume Node.Pos is accurate + // and combine type-based filtering with position filtering + // like FindByPos? + + mask := maskOf([]ast.Node{n}) + events := c.in.events + + for i, limit := c.indices(); i < limit; i++ { + ev := events[i] + if ev.index > i { // push? + if ev.typ&mask != 0 && ev.node == n { + return Cursor{c.in, i}, true + } + pop := ev.index + if events[pop].typ&mask == 0 { + // Subtree does not contain type of n: skip. + i = pop + } + } + } + return Cursor{}, false +} + +// FindByPos returns the cursor for the innermost node n in the tree +// rooted at c such that n.Pos() <= start && end <= n.End(). +// (For an *ast.File, it uses the bounds n.FileStart-n.FileEnd.) +// +// It returns zero if none is found. +// Precondition: start <= end. +// +// See also [astutil.PathEnclosingInterval], which +// tolerates adjoining whitespace. 
+func (c Cursor) FindByPos(start, end token.Pos) (Cursor, bool) { + if end < start { + panic("end < start") + } + events := c.in.events + + // This algorithm could be implemented using c.Inspect, + // but it is about 2.5x slower. + + best := int32(-1) // push index of latest (=innermost) node containing range + for i, limit := c.indices(); i < limit; i++ { + ev := events[i] + if ev.index > i { // push? + n := ev.node + var nodeEnd token.Pos + if file, ok := n.(*ast.File); ok { + nodeEnd = file.FileEnd + // Note: files may be out of Pos order. + if file.FileStart > start { + i = ev.index // disjoint, after; skip to next file + continue + } + } else { + nodeEnd = n.End() + if n.Pos() > start { + break // disjoint, after; stop + } + } + // Inv: node.{Pos,FileStart} <= start + if end <= nodeEnd { + // node fully contains target range + best = i + } else if nodeEnd < start { + i = ev.index // disjoint, before; skip forward + } + } + } + if best >= 0 { + return Cursor{c.in, best}, true + } + return Cursor{}, false +} diff --git a/internal/astutil/cursor/cursor_test.go b/go/ast/inspector/cursor_test.go similarity index 84% rename from internal/astutil/cursor/cursor_test.go rename to go/ast/inspector/cursor_test.go index 0573512fc3b..90067c67060 100644 --- a/internal/astutil/cursor/cursor_test.go +++ b/go/ast/inspector/cursor_test.go @@ -1,88 +1,25 @@ -// Copyright 2024 The Go Authors. All rights reserved. +// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-//go:build go1.23 - -package cursor_test +package inspector_test import ( "fmt" "go/ast" - "go/build" "go/parser" "go/token" "iter" - "log" "math/rand" - "path/filepath" "reflect" "slices" "strings" "testing" "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/astutil/edge" ) -// net/http package -var ( - netFset = token.NewFileSet() - netFiles []*ast.File - netInspect *inspector.Inspector -) - -func init() { - files, err := parseNetFiles() - if err != nil { - log.Fatal(err) - } - netFiles = files - netInspect = inspector.New(netFiles) -} - -func parseNetFiles() ([]*ast.File, error) { - pkg, err := build.Default.Import("net", "", 0) - if err != nil { - return nil, err - } - var files []*ast.File - for _, filename := range pkg.GoFiles { - filename = filepath.Join(pkg.Dir, filename) - f, err := parser.ParseFile(netFset, filename, nil, 0) - if err != nil { - return nil, err - } - files = append(files, f) - } - return files, nil -} - -// compare calls t.Error if !slices.Equal(nodesA, nodesB). -func compare[N comparable](t *testing.T, nodesA, nodesB []N) { - if len(nodesA) != len(nodesB) { - t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB)) - } else { - for i := range nodesA { - if a, b := nodesA[i], nodesB[i]; a != b { - t.Errorf("node %d is inconsistent: %T, %T", i, a, b) - } - } - } -} - -// firstN(n, seq), returns a slice of up to n elements of seq. -func firstN[T any](n int, seq iter.Seq[T]) (res []T) { - for x := range seq { - res = append(res, x) - if len(res) == n { - break - } - } - return res -} - func TestCursor_Preorder(t *testing.T) { inspect := netInspect @@ -90,7 +27,7 @@ func TestCursor_Preorder(t *testing.T) { // reference implementation var want []ast.Node - for cur := range cursor.Root(inspect).Preorder(nodeFilter...) { + for cur := range inspect.Root().Preorder(nodeFilter...) 
{ want = append(want, cur.Node()) } @@ -100,7 +37,7 @@ func TestCursor_Preorder(t *testing.T) { // Check that break works. got = got[:0] - for _, c := range firstN(10, cursor.Root(inspect).Preorder(nodeFilter...)) { + for _, c := range firstN(10, inspect.Root().Preorder(nodeFilter...)) { got = append(got, c.Node()) } compare(t, got, want[:10]) @@ -127,7 +64,7 @@ func g() { ncalls = 0 ) - for curFunc := range cursor.Root(inspect).Preorder(funcDecls...) { + for curFunc := range inspect.Root().Preorder(funcDecls...) { _ = curFunc.Node().(*ast.FuncDecl) // Check edge and index. @@ -167,7 +104,7 @@ func g() { // nested Inspect traversal inspectCount := 0 - curFunc.Inspect(callExprs, func(curCall cursor.Cursor) (proceed bool) { + curFunc.Inspect(callExprs, func(curCall inspector.Cursor) (proceed bool) { _ = curCall.Node().(*ast.CallExpr) inspectCount++ stack := slices.Collect(curCall.Enclosing()) @@ -198,7 +135,7 @@ func TestCursor_Children(t *testing.T) { // Assert that Cursor.Children agrees with // reference implementation for every node. var want, got []ast.Node - for c := range cursor.Root(inspect).Preorder() { + for c := range inspect.Root().Preorder() { // reference implementation want = want[:0] @@ -265,7 +202,7 @@ func TestCursor_Inspect(t *testing.T) { // Test Cursor.Inspect implementation. var nodesB []ast.Node - cursor.Root(inspect).Inspect(switches, func(c cursor.Cursor) (proceed bool) { + inspect.Root().Inspect(switches, func(c inspector.Cursor) (proceed bool) { n := c.Node() nodesB = append(nodesB, n) return !is[*ast.SwitchStmt](n) // descend only into TypeSwitchStmt @@ -293,7 +230,7 @@ func TestCursor_FindNode(t *testing.T) { // starting at the root. // // (We use BasicLit because they are numerous.) 
- root := cursor.Root(inspect) + root := inspect.Root() for c := range root.Preorder((*ast.BasicLit)(nil)) { node := c.Node() got, ok := root.FindNode(node) @@ -333,7 +270,7 @@ func TestCursor_FindPos_order(t *testing.T) { target := netFiles[7].Decls[0] // Find the target decl by its position. - cur, ok := cursor.Root(netInspect).FindByPos(target.Pos(), target.End()) + cur, ok := netInspect.Root().FindByPos(target.Pos(), target.End()) if !ok || cur.Node() != target { t.Fatalf("unshuffled: FindPos(%T) = (%v, %t)", target, cur, ok) } @@ -346,14 +283,14 @@ func TestCursor_FindPos_order(t *testing.T) { // Find it again. inspect := inspector.New(files) - cur, ok = cursor.Root(inspect).FindByPos(target.Pos(), target.End()) + cur, ok = inspect.Root().FindByPos(target.Pos(), target.End()) if !ok || cur.Node() != target { t.Fatalf("shuffled: FindPos(%T) = (%v, %t)", target, cur, ok) } } func TestCursor_Edge(t *testing.T) { - root := cursor.Root(netInspect) + root := netInspect.Root() for cur := range root.Preorder() { if cur == root { continue // root node @@ -462,10 +399,8 @@ func sliceTypes[T any](slice []T) string { return buf.String() } -// (partially duplicates benchmark in go/ast/inspector) func BenchmarkInspectCalls(b *testing.B) { inspect := netInspect - b.ResetTimer() // Measure marginal cost of traversal. @@ -497,7 +432,7 @@ func BenchmarkInspectCalls(b *testing.B) { b.Run("Cursor", func(b *testing.B) { var ncalls int for range b.N { - for cur := range cursor.Root(inspect).Preorder(callExprs...) { + for cur := range inspect.Root().Preorder(callExprs...) { _ = cur.Node().(*ast.CallExpr) ncalls++ } @@ -507,7 +442,7 @@ func BenchmarkInspectCalls(b *testing.B) { b.Run("CursorEnclosing", func(b *testing.B) { var ncalls int for range b.N { - for cur := range cursor.Root(inspect).Preorder(callExprs...) { + for cur := range inspect.Root().Preorder(callExprs...) 
{ _ = cur.Node().(*ast.CallExpr) for range cur.Enclosing() { } @@ -519,13 +454,13 @@ func BenchmarkInspectCalls(b *testing.B) { // This benchmark compares methods for finding a known node in a tree. func BenchmarkCursor_FindNode(b *testing.B) { - root := cursor.Root(netInspect) + root := netInspect.Root() callExprs := []ast.Node{(*ast.CallExpr)(nil)} // Choose a needle in the haystack to use as the search target: // a CallExpr not too near the start nor at too shallow a depth. - var needle cursor.Cursor + var needle inspector.Cursor { count := 0 found := false @@ -547,7 +482,7 @@ func BenchmarkCursor_FindNode(b *testing.B) { b.Run("Cursor.Preorder", func(b *testing.B) { needleNode := needle.Node() for range b.N { - var found cursor.Cursor + var found inspector.Cursor for c := range root.Preorder(callExprs...) { if c.Node() == needleNode { found = c diff --git a/go/ast/inspector/inspector.go b/go/ast/inspector/inspector.go index 674490a65b4..b07318ac4c5 100644 --- a/go/ast/inspector/inspector.go +++ b/go/ast/inspector/inspector.go @@ -48,18 +48,12 @@ type Inspector struct { events []event } -//go:linkname events golang.org/x/tools/go/ast/inspector.events -func events(in *Inspector) []event { return in.events } - -//go:linkname packEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.packEdgeKindAndIndex func packEdgeKindAndIndex(ek edge.Kind, index int) int32 { return int32(uint32(index+1)<<7 | uint32(ek)) } // unpackEdgeKindAndIndex unpacks the edge kind and edge index (within // an []ast.Node slice) from the parent field of a pop event. -// -//go:linkname unpackEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.unpackEdgeKindAndIndex func unpackEdgeKindAndIndex(x int32) (edge.Kind, int) { // The "parent" field of a pop node holds the // edge Kind in the lower 7 bits and the index+1 @@ -92,6 +86,11 @@ type event struct { // The types argument, if non-empty, enables type-based filtering of // events. 
The function f is called only for nodes whose type // matches an element of the types slice. +// +// The [Cursor.Preorder] method provides a richer alternative interface. +// Example: +// +// for c := range in.Root().Preorder(types...) { ... } func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // Because it avoids postorder calls to f, and the pruning // check, Preorder is almost twice as fast as Nodes. The two @@ -135,6 +134,14 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // The types argument, if non-empty, enables type-based filtering of // events. The function f if is called only for nodes whose type // matches an element of the types slice. +// +// The [Cursor.Inspect] method provides a richer alternative interface. +// Example: +// +// in.Root().Inspect(types, func(c Cursor) bool { +// ... +// return true +// }) func (in *Inspector) Nodes(types []ast.Node, f func(n ast.Node, push bool) (proceed bool)) { mask := maskOf(types) for i := int32(0); i < int32(len(in.events)); { @@ -168,6 +175,15 @@ func (in *Inspector) Nodes(types []ast.Node, f func(n ast.Node, push bool) (proc // supplies each call to f an additional argument, the current // traversal stack. The stack's first element is the outermost node, // an *ast.File; its last is the innermost, n. +// +// The [Cursor.Inspect] method provides a richer alternative interface. +// Example: +// +// in.Root().Inspect(types, func(c Cursor) bool { +// stack := slices.Collect(c.Enclosing()) +// ...
+// return true +// }) func (in *Inspector) WithStack(types []ast.Node, f func(n ast.Node, push bool, stack []ast.Node) (proceed bool)) { mask := maskOf(types) var stack []ast.Node diff --git a/go/ast/inspector/inspector_test.go b/go/ast/inspector/inspector_test.go index a19ba653e0a..4c017ce2dc8 100644 --- a/go/ast/inspector/inspector_test.go +++ b/go/ast/inspector/inspector_test.go @@ -19,7 +19,12 @@ import ( "golang.org/x/tools/go/ast/inspector" ) -var netFiles []*ast.File +// net/http package +var ( + netFset = token.NewFileSet() + netFiles []*ast.File + netInspect *inspector.Inspector +) func init() { files, err := parseNetFiles() @@ -27,6 +32,7 @@ func init() { log.Fatal(err) } netFiles = files + netInspect = inspector.New(netFiles) } func parseNetFiles() ([]*ast.File, error) { @@ -34,11 +40,10 @@ func parseNetFiles() ([]*ast.File, error) { if err != nil { return nil, err } - fset := token.NewFileSet() var files []*ast.File for _, filename := range pkg.GoFiles { filename = filepath.Join(pkg.Dir, filename) - f, err := parser.ParseFile(fset, filename, nil, 0) + f, err := parser.ParseFile(netFset, filename, nil, 0) if err != nil { return nil, err } @@ -292,22 +297,6 @@ func BenchmarkInspectFilter(b *testing.B) { } } -func BenchmarkInspectCalls(b *testing.B) { - b.StopTimer() - inspect := inspector.New(netFiles) - b.StartTimer() - - // Measure marginal cost of traversal. - nodeFilter := []ast.Node{(*ast.CallExpr)(nil)} - var ncalls int - for i := 0; i < b.N; i++ { - inspect.Preorder(nodeFilter, func(n ast.Node) { - _ = n.(*ast.CallExpr) - ncalls++ - }) - } -} - func BenchmarkASTInspect(b *testing.B) { var ndecls, nlits int for i := 0; i < b.N; i++ { diff --git a/go/ast/inspector/iter_test.go b/go/ast/inspector/iter_test.go index 2f52998c558..99882c65be8 100644 --- a/go/ast/inspector/iter_test.go +++ b/go/ast/inspector/iter_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-//go:build go1.23 - package inspector_test import ( diff --git a/internal/astutil/cursor/cursor.go b/internal/astutil/cursor/cursor.go index 78d874a86fa..e328c484a08 100644 --- a/internal/astutil/cursor/cursor.go +++ b/internal/astutil/cursor/cursor.go @@ -2,500 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.23 - -// Package cursor augments [inspector.Inspector] with [Cursor] -// functionality allowing more flexibility and control during -// inspection. -// -// This package is a temporary private extension of inspector until -// proposal #70859 is accepted, and which point it will be moved into -// the inspector package, and [Root] will become a method of -// Inspector. +// Package cursor is deprecated; use [inspector.Cursor]. package cursor -import ( - "fmt" - "go/ast" - "go/token" - "iter" - "reflect" - - "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/edge" -) - -// A Cursor represents an [ast.Node]. It is immutable. -// -// Two Cursors compare equal if they represent the same node. -// -// Call [Root] to obtain a valid cursor. -type Cursor struct { - in *inspector.Inspector - index int32 // index of push node; -1 for virtual root node -} - -// Root returns a cursor for the virtual root node, -// whose children are the files provided to [New]. -// -// Its [Cursor.Node] and [Cursor.Stack] methods return nil. -func Root(in *inspector.Inspector) Cursor { - return Cursor{in, -1} -} - -// At returns the cursor at the specified index in the traversal, -// which must have been obtained from [Cursor.Index] on a Cursor -// belonging to the same Inspector (see [Cursor.Inspector]). 
-func At(in *inspector.Inspector, index int32) Cursor { - if index < 0 { - panic("negative index") - } - events := events(in) - if int(index) >= len(events) { - panic("index out of range for this inspector") - } - if events[index].index < index { - panic("invalid index") // (a push, not a pop) - } - return Cursor{in, index} -} - -// Inspector returns the cursor's Inspector. -func (c Cursor) Inspector() *inspector.Inspector { return c.in } - -// Index returns the index of this cursor position within the package. -// -// Clients should not assume anything about the numeric Index value -// except that it increases monotonically throughout the traversal. -// It is provided for use with [At]. -// -// Index must not be called on the Root node. -func (c Cursor) Index() int32 { - if c.index < 0 { - panic("Index called on Root node") - } - return c.index -} - -// Node returns the node at the current cursor position, -// or nil for the cursor returned by [Inspector.Root]. -func (c Cursor) Node() ast.Node { - if c.index < 0 { - return nil - } - return c.events()[c.index].node -} - -// String returns information about the cursor's node, if any. -func (c Cursor) String() string { - if c.in == nil { - return "(invalid)" - } - if c.index < 0 { - return "(root)" - } - return reflect.TypeOf(c.Node()).String() -} - -// indices return the [start, end) half-open interval of event indices. -func (c Cursor) indices() (int32, int32) { - if c.index < 0 { - return 0, int32(len(c.events())) // root: all events - } else { - return c.index, c.events()[c.index].index + 1 // just one subtree - } -} - -// Preorder returns an iterator over the nodes of the subtree -// represented by c in depth-first order. Each node in the sequence is -// represented by a Cursor that allows access to the Node, but may -// also be used to start a new traversal, or to obtain the stack of -// nodes enclosing the cursor. -// -// The traversal sequence is determined by [ast.Inspect]. 
The types -// argument, if non-empty, enables type-based filtering of events. The -// function f if is called only for nodes whose type matches an -// element of the types slice. -// -// If you need control over descent into subtrees, -// or need both pre- and post-order notifications, use [Cursor.Inspect] -func (c Cursor) Preorder(types ...ast.Node) iter.Seq[Cursor] { - mask := maskOf(types) - - return func(yield func(Cursor) bool) { - events := c.events() - - for i, limit := c.indices(); i < limit; { - ev := events[i] - if ev.index > i { // push? - if ev.typ&mask != 0 && !yield(Cursor{c.in, i}) { - break - } - pop := ev.index - if events[pop].typ&mask == 0 { - // Subtree does not contain types: skip. - i = pop + 1 - continue - } - } - i++ - } - } -} - -// Inspect visits the nodes of the subtree represented by c in -// depth-first order. It calls f(n) for each node n before it -// visits n's children. If f returns true, Inspect invokes f -// recursively for each of the non-nil children of the node. -// -// Each node is represented by a Cursor that allows access to the -// Node, but may also be used to start a new traversal, or to obtain -// the stack of nodes enclosing the cursor. -// -// The complete traversal sequence is determined by [ast.Inspect]. -// The types argument, if non-empty, enables type-based filtering of -// events. The function f if is called only for nodes whose type -// matches an element of the types slice. -func (c Cursor) Inspect(types []ast.Node, f func(c Cursor) (descend bool)) { - mask := maskOf(types) - events := c.events() - for i, limit := c.indices(); i < limit; { - ev := events[i] - if ev.index > i { - // push - pop := ev.index - if ev.typ&mask != 0 && !f(Cursor{c.in, i}) || - events[pop].typ&mask == 0 { - // The user opted not to descend, or the - // subtree does not contain types: - // skip past the pop. 
- i = pop + 1 - continue - } - } - i++ - } -} - -// Enclosing returns an iterator over the nodes enclosing the current -// current node, starting with the Cursor itself. -// -// Enclosing must not be called on the Root node (whose [Cursor.Node] returns nil). -// -// The types argument, if non-empty, enables type-based filtering of -// events: the sequence includes only enclosing nodes whose type -// matches an element of the types slice. -func (c Cursor) Enclosing(types ...ast.Node) iter.Seq[Cursor] { - if c.index < 0 { - panic("Cursor.Enclosing called on Root node") - } - - mask := maskOf(types) - - return func(yield func(Cursor) bool) { - events := c.events() - for i := c.index; i >= 0; i = events[i].parent { - if events[i].typ&mask != 0 && !yield(Cursor{c.in, i}) { - break - } - } - } -} - -// Parent returns the parent of the current node. -// -// Parent must not be called on the Root node (whose [Cursor.Node] returns nil). -func (c Cursor) Parent() Cursor { - if c.index < 0 { - panic("Cursor.Parent called on Root node") - } - - return Cursor{c.in, c.events()[c.index].parent} -} - -// ParentEdge returns the identity of the field in the parent node -// that holds this cursor's node, and if it is a list, the index within it. -// -// For example, f(x, y) is a CallExpr whose three children are Idents. -// f has edge kind [edge.CallExpr_Fun] and index -1. -// x and y have kind [edge.CallExpr_Args] and indices 0 and 1, respectively. -// -// If called on a child of the Root node, it returns ([edge.Invalid], -1). -// -// ParentEdge must not be called on the Root node (whose [Cursor.Node] returns nil). -func (c Cursor) ParentEdge() (edge.Kind, int) { - if c.index < 0 { - panic("Cursor.ParentEdge called on Root node") - } - events := c.events() - pop := events[c.index].index - return unpackEdgeKindAndIndex(events[pop].parent) -} - -// ChildAt returns the cursor for the child of the -// current node identified by its edge and index. 
-// The index must be -1 if the edge.Kind is not a slice. -// The indicated child node must exist. -// -// ChildAt must not be called on the Root node (whose [Cursor.Node] returns nil). -// -// Invariant: c.Parent().ChildAt(c.ParentEdge()) == c. -func (c Cursor) ChildAt(k edge.Kind, idx int) Cursor { - target := packEdgeKindAndIndex(k, idx) - - // Unfortunately there's no shortcut to looping. - events := c.events() - i := c.index + 1 - for { - pop := events[i].index - if pop < i { - break - } - if events[pop].parent == target { - return Cursor{c.in, i} - } - i = pop + 1 - } - panic(fmt.Sprintf("ChildAt(%v, %d): no such child of %v", k, idx, c)) -} - -// Child returns the cursor for n, which must be a direct child of c's Node. -// -// Child must not be called on the Root node (whose [Cursor.Node] returns nil). -func (c Cursor) Child(n ast.Node) Cursor { - if c.index < 0 { - panic("Cursor.Child called on Root node") - } - - if false { - // reference implementation - for child := range c.Children() { - if child.Node() == n { - return child - } - } - - } else { - // optimized implementation - events := c.events() - for i := c.index + 1; events[i].index > i; i = events[i].index + 1 { - if events[i].node == n { - return Cursor{c.in, i} - } - } - } - panic(fmt.Sprintf("Child(%T): not a child of %v", n, c)) -} - -// NextSibling returns the cursor for the next sibling node in the same list -// (for example, of files, decls, specs, statements, fields, or expressions) as -// the current node. It returns (zero, false) if the node is the last node in -// the list, or is not part of a list. -// -// NextSibling must not be called on the Root node. -// -// See note at [Cursor.Children]. -func (c Cursor) NextSibling() (Cursor, bool) { - if c.index < 0 { - panic("Cursor.NextSibling called on Root node") - } - - events := c.events() - i := events[c.index].index + 1 // after corresponding pop - if i < int32(len(events)) { - if events[i].index > i { // push? 
- return Cursor{c.in, i}, true - } - } - return Cursor{}, false -} - -// PrevSibling returns the cursor for the previous sibling node in the -// same list (for example, of files, decls, specs, statements, fields, -// or expressions) as the current node. It returns zero if the node is -// the first node in the list, or is not part of a list. -// -// It must not be called on the Root node. -// -// See note at [Cursor.Children]. -func (c Cursor) PrevSibling() (Cursor, bool) { - if c.index < 0 { - panic("Cursor.PrevSibling called on Root node") - } - - events := c.events() - i := c.index - 1 - if i >= 0 { - if j := events[i].index; j < i { // pop? - return Cursor{c.in, j}, true - } - } - return Cursor{}, false -} - -// FirstChild returns the first direct child of the current node, -// or zero if it has no children. -func (c Cursor) FirstChild() (Cursor, bool) { - events := c.events() - i := c.index + 1 // i=0 if c is root - if i < int32(len(events)) && events[i].index > i { // push? - return Cursor{c.in, i}, true - } - return Cursor{}, false -} - -// LastChild returns the last direct child of the current node, -// or zero if it has no children. -func (c Cursor) LastChild() (Cursor, bool) { - events := c.events() - if c.index < 0 { // root? - if len(events) > 0 { - // return push of final event (a pop) - return Cursor{c.in, events[len(events)-1].index}, true - } - } else { - j := events[c.index].index - 1 // before corresponding pop - // Inv: j == c.index if c has no children - // or j is last child's pop. - if j > c.index { // c has children - return Cursor{c.in, events[j].index}, true - } - } - return Cursor{}, false -} - -// Children returns an iterator over the direct children of the -// current node, if any. -// -// When using Children, NextChild, and PrevChild, bear in mind that a -// Node's children may come from different fields, some of which may -// be lists of nodes without a distinguished intervening container -// such as [ast.BlockStmt]. 
-// -// For example, [ast.CaseClause] has a field List of expressions and a -// field Body of statements, so the children of a CaseClause are a mix -// of expressions and statements. Other nodes that have "uncontained" -// list fields include: -// -// - [ast.ValueSpec] (Names, Values) -// - [ast.CompositeLit] (Type, Elts) -// - [ast.IndexListExpr] (X, Indices) -// - [ast.CallExpr] (Fun, Args) -// - [ast.AssignStmt] (Lhs, Rhs) -// -// So, do not assume that the previous sibling of an ast.Stmt is also -// an ast.Stmt, or if it is, that they are executed sequentially, -// unless you have established that, say, its parent is a BlockStmt -// or its [Cursor.ParentEdge] is [edge.BlockStmt_List]. -// For example, given "for S1; ; S2 {}", the predecessor of S2 is S1, -// even though they are not executed in sequence. -func (c Cursor) Children() iter.Seq[Cursor] { - return func(yield func(Cursor) bool) { - c, ok := c.FirstChild() - for ok && yield(c) { - c, ok = c.NextSibling() - } - } -} - -// Contains reports whether c contains or is equal to c2. -// -// Both Cursors must belong to the same [Inspector]; -// neither may be its Root node. -func (c Cursor) Contains(c2 Cursor) bool { - if c.in != c2.in { - panic("different inspectors") - } - events := c.events() - return c.index <= c2.index && events[c2.index].index <= events[c.index].index -} - -// FindNode returns the cursor for node n if it belongs to the subtree -// rooted at c. It returns zero if n is not found. -func (c Cursor) FindNode(n ast.Node) (Cursor, bool) { - - // FindNode is equivalent to this code, - // but more convenient and 15-20% faster: - if false { - for candidate := range c.Preorder(n) { - if candidate.Node() == n { - return candidate, true - } - } - return Cursor{}, false - } - - // TODO(adonovan): opt: should we assume Node.Pos is accurate - // and combine type-based filtering with position filtering - // like FindByPos? 
- - mask := maskOf([]ast.Node{n}) - events := c.events() - - for i, limit := c.indices(); i < limit; i++ { - ev := events[i] - if ev.index > i { // push? - if ev.typ&mask != 0 && ev.node == n { - return Cursor{c.in, i}, true - } - pop := ev.index - if events[pop].typ&mask == 0 { - // Subtree does not contain type of n: skip. - i = pop - } - } - } - return Cursor{}, false -} +import "golang.org/x/tools/go/ast/inspector" -// FindByPos returns the cursor for the innermost node n in the tree -// rooted at c such that n.Pos() <= start && end <= n.End(). -// (For an *ast.File, it uses the bounds n.FileStart-n.FileEnd.) -// -// It returns zero if none is found. -// Precondition: start <= end. -// -// See also [astutil.PathEnclosingInterval], which -// tolerates adjoining whitespace. -func (c Cursor) FindByPos(start, end token.Pos) (Cursor, bool) { - if end < start { - panic("end < start") - } - events := c.events() +//go:fix inline +type Cursor = inspector.Cursor - // This algorithm could be implemented using c.Inspect, - // but it is about 2.5x slower. +//go:fix inline +func Root(in *inspector.Inspector) inspector.Cursor { return in.Root() } - best := int32(-1) // push index of latest (=innermost) node containing range - for i, limit := c.indices(); i < limit; i++ { - ev := events[i] - if ev.index > i { // push? - n := ev.node - var nodeEnd token.Pos - if file, ok := n.(*ast.File); ok { - nodeEnd = file.FileEnd - // Note: files may be out of Pos order. 
- if file.FileStart > start { - i = ev.index // disjoint, after; skip to next file - continue - } - } else { - nodeEnd = n.End() - if n.Pos() > start { - break // disjoint, after; stop - } - } - // Inv: node.{Pos,FileStart} <= start - if end <= nodeEnd { - // node fully contains target range - best = i - } else if nodeEnd < start { - i = ev.index // disjoint, before; skip forward - } - } - } - if best >= 0 { - return Cursor{c.in, best}, true - } - return Cursor{}, false -} +//go:fix inline +func At(in *inspector.Inspector, index int32) inspector.Cursor { return in.At(index) } diff --git a/internal/astutil/cursor/hooks.go b/internal/astutil/cursor/hooks.go deleted file mode 100644 index 0257d61d778..00000000000 --- a/internal/astutil/cursor/hooks.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.23 - -package cursor - -import ( - "go/ast" - _ "unsafe" // for go:linkname - - "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/edge" -) - -// This file defines backdoor access to inspector. - -// Copied from inspector.event; must remain in sync. -// (Note that the linkname effects a type coercion too.) 
-type event struct { - node ast.Node - typ uint64 // typeOf(node) on push event, or union of typ strictly between push and pop events on pop events - index int32 // index of corresponding push or pop event (relative to this event's index, +ve=push, -ve=pop) - parent int32 // index of parent's push node (push nodes only); or edge and index, bit packed (pop nodes only) -} - -//go:linkname maskOf golang.org/x/tools/go/ast/inspector.maskOf -func maskOf(nodes []ast.Node) uint64 - -//go:linkname events golang.org/x/tools/go/ast/inspector.events -func events(in *inspector.Inspector) []event - -//go:linkname packEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.packEdgeKindAndIndex -func packEdgeKindAndIndex(edge.Kind, int) int32 - -//go:linkname unpackEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.unpackEdgeKindAndIndex -func unpackEdgeKindAndIndex(int32) (edge.Kind, int) - -func (c Cursor) events() []event { return events(c.in) } diff --git a/internal/astutil/edge/edge.go b/internal/astutil/edge/edge.go index 4f6ccfd6e5e..c3f7661cbad 100644 --- a/internal/astutil/edge/edge.go +++ b/internal/astutil/edge/edge.go @@ -6,290 +6,122 @@ // struct type that refers to another Node. package edge -import ( - "fmt" - "go/ast" - "reflect" -) - -// A Kind describes a field of an ast.Node struct. -type Kind uint8 - -// String returns a description of the edge kind. -func (k Kind) String() string { - if k == Invalid { - return "" - } - info := fieldInfos[k] - return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.name) -} - -// NodeType returns the pointer-to-struct type of the ast.Node implementation. -func (k Kind) NodeType() reflect.Type { return fieldInfos[k].nodeType } - -// FieldName returns the name of the field. -func (k Kind) FieldName() string { return fieldInfos[k].name } - -// FieldType returns the declared type of the field. 
-func (k Kind) FieldType() reflect.Type { return fieldInfos[k].fieldType } +import "golang.org/x/tools/go/ast/edge" -// Get returns the direct child of n identified by (k, idx). -// n's type must match k.NodeType(). -// idx must be a valid slice index, or -1 for a non-slice. -func (k Kind) Get(n ast.Node, idx int) ast.Node { - if k.NodeType() != reflect.TypeOf(n) { - panic(fmt.Sprintf("%v.Get(%T): invalid node type", k, n)) - } - v := reflect.ValueOf(n).Elem().Field(fieldInfos[k].index) - if idx != -1 { - v = v.Index(idx) // asserts valid index - } else { - // (The type assertion below asserts that v is not a slice.) - } - return v.Interface().(ast.Node) // may be nil -} +//go:fix inline +type Kind = edge.Kind +//go:fix inline const ( - Invalid Kind = iota // for nodes at the root of the traversal + Invalid Kind = edge.Invalid // Kinds are sorted alphabetically. // Numbering is not stable. // Each is named Type_Field, where Type is the // ast.Node struct type and Field is the name of the field - ArrayType_Elt - ArrayType_Len - AssignStmt_Lhs - AssignStmt_Rhs - BinaryExpr_X - BinaryExpr_Y - BlockStmt_List - BranchStmt_Label - CallExpr_Args - CallExpr_Fun - CaseClause_Body - CaseClause_List - ChanType_Value - CommClause_Body - CommClause_Comm - CommentGroup_List - CompositeLit_Elts - CompositeLit_Type - DeclStmt_Decl - DeferStmt_Call - Ellipsis_Elt - ExprStmt_X - FieldList_List - Field_Comment - Field_Doc - Field_Names - Field_Tag - Field_Type - File_Decls - File_Doc - File_Name - ForStmt_Body - ForStmt_Cond - ForStmt_Init - ForStmt_Post - FuncDecl_Body - FuncDecl_Doc - FuncDecl_Name - FuncDecl_Recv - FuncDecl_Type - FuncLit_Body - FuncLit_Type - FuncType_Params - FuncType_Results - FuncType_TypeParams - GenDecl_Doc - GenDecl_Specs - GoStmt_Call - IfStmt_Body - IfStmt_Cond - IfStmt_Else - IfStmt_Init - ImportSpec_Comment - ImportSpec_Doc - ImportSpec_Name - ImportSpec_Path - IncDecStmt_X - IndexExpr_Index - IndexExpr_X - IndexListExpr_Indices - IndexListExpr_X - 
InterfaceType_Methods - KeyValueExpr_Key - KeyValueExpr_Value - LabeledStmt_Label - LabeledStmt_Stmt - MapType_Key - MapType_Value - ParenExpr_X - RangeStmt_Body - RangeStmt_Key - RangeStmt_Value - RangeStmt_X - ReturnStmt_Results - SelectStmt_Body - SelectorExpr_Sel - SelectorExpr_X - SendStmt_Chan - SendStmt_Value - SliceExpr_High - SliceExpr_Low - SliceExpr_Max - SliceExpr_X - StarExpr_X - StructType_Fields - SwitchStmt_Body - SwitchStmt_Init - SwitchStmt_Tag - TypeAssertExpr_Type - TypeAssertExpr_X - TypeSpec_Comment - TypeSpec_Doc - TypeSpec_Name - TypeSpec_Type - TypeSpec_TypeParams - TypeSwitchStmt_Assign - TypeSwitchStmt_Body - TypeSwitchStmt_Init - UnaryExpr_X - ValueSpec_Comment - ValueSpec_Doc - ValueSpec_Names - ValueSpec_Type - ValueSpec_Values - - maxKind + ArrayType_Elt = edge.ArrayType_Elt + ArrayType_Len = edge.ArrayType_Len + AssignStmt_Lhs = edge.AssignStmt_Lhs + AssignStmt_Rhs = edge.AssignStmt_Rhs + BinaryExpr_X = edge.BinaryExpr_X + BinaryExpr_Y = edge.BinaryExpr_Y + BlockStmt_List = edge.BlockStmt_List + BranchStmt_Label = edge.BranchStmt_Label + CallExpr_Args = edge.CallExpr_Args + CallExpr_Fun = edge.CallExpr_Fun + CaseClause_Body = edge.CaseClause_Body + CaseClause_List = edge.CaseClause_List + ChanType_Value = edge.ChanType_Value + CommClause_Body = edge.CommClause_Body + CommClause_Comm = edge.CommClause_Comm + CommentGroup_List = edge.CommentGroup_List + CompositeLit_Elts = edge.CompositeLit_Elts + CompositeLit_Type = edge.CompositeLit_Type + DeclStmt_Decl = edge.DeclStmt_Decl + DeferStmt_Call = edge.DeferStmt_Call + Ellipsis_Elt = edge.Ellipsis_Elt + ExprStmt_X = edge.ExprStmt_X + FieldList_List = edge.FieldList_List + Field_Comment = edge.Field_Comment + Field_Doc = edge.Field_Doc + Field_Names = edge.Field_Names + Field_Tag = edge.Field_Tag + Field_Type = edge.Field_Type + File_Decls = edge.File_Decls + File_Doc = edge.File_Doc + File_Name = edge.File_Name + ForStmt_Body = edge.ForStmt_Body + ForStmt_Cond = edge.ForStmt_Cond + 
ForStmt_Init = edge.ForStmt_Init + ForStmt_Post = edge.ForStmt_Post + FuncDecl_Body = edge.FuncDecl_Body + FuncDecl_Doc = edge.FuncDecl_Doc + FuncDecl_Name = edge.FuncDecl_Name + FuncDecl_Recv = edge.FuncDecl_Recv + FuncDecl_Type = edge.FuncDecl_Type + FuncLit_Body = edge.FuncLit_Body + FuncLit_Type = edge.FuncLit_Type + FuncType_Params = edge.FuncType_Params + FuncType_Results = edge.FuncType_Results + FuncType_TypeParams = edge.FuncType_TypeParams + GenDecl_Doc = edge.GenDecl_Doc + GenDecl_Specs = edge.GenDecl_Specs + GoStmt_Call = edge.GoStmt_Call + IfStmt_Body = edge.IfStmt_Body + IfStmt_Cond = edge.IfStmt_Cond + IfStmt_Else = edge.IfStmt_Else + IfStmt_Init = edge.IfStmt_Init + ImportSpec_Comment = edge.ImportSpec_Comment + ImportSpec_Doc = edge.ImportSpec_Doc + ImportSpec_Name = edge.ImportSpec_Name + ImportSpec_Path = edge.ImportSpec_Path + IncDecStmt_X = edge.IncDecStmt_X + IndexExpr_Index = edge.IndexExpr_Index + IndexExpr_X = edge.IndexExpr_X + IndexListExpr_Indices = edge.IndexListExpr_Indices + IndexListExpr_X = edge.IndexListExpr_X + InterfaceType_Methods = edge.InterfaceType_Methods + KeyValueExpr_Key = edge.KeyValueExpr_Key + KeyValueExpr_Value = edge.KeyValueExpr_Value + LabeledStmt_Label = edge.LabeledStmt_Label + LabeledStmt_Stmt = edge.LabeledStmt_Stmt + MapType_Key = edge.MapType_Key + MapType_Value = edge.MapType_Value + ParenExpr_X = edge.ParenExpr_X + RangeStmt_Body = edge.RangeStmt_Body + RangeStmt_Key = edge.RangeStmt_Key + RangeStmt_Value = edge.RangeStmt_Value + RangeStmt_X = edge.RangeStmt_X + ReturnStmt_Results = edge.ReturnStmt_Results + SelectStmt_Body = edge.SelectStmt_Body + SelectorExpr_Sel = edge.SelectorExpr_Sel + SelectorExpr_X = edge.SelectorExpr_X + SendStmt_Chan = edge.SendStmt_Chan + SendStmt_Value = edge.SendStmt_Value + SliceExpr_High = edge.SliceExpr_High + SliceExpr_Low = edge.SliceExpr_Low + SliceExpr_Max = edge.SliceExpr_Max + SliceExpr_X = edge.SliceExpr_X + StarExpr_X = edge.StarExpr_X + StructType_Fields = 
edge.StructType_Fields + SwitchStmt_Body = edge.SwitchStmt_Body + SwitchStmt_Init = edge.SwitchStmt_Init + SwitchStmt_Tag = edge.SwitchStmt_Tag + TypeAssertExpr_Type = edge.TypeAssertExpr_Type + TypeAssertExpr_X = edge.TypeAssertExpr_X + TypeSpec_Comment = edge.TypeSpec_Comment + TypeSpec_Doc = edge.TypeSpec_Doc + TypeSpec_Name = edge.TypeSpec_Name + TypeSpec_Type = edge.TypeSpec_Type + TypeSpec_TypeParams = edge.TypeSpec_TypeParams + TypeSwitchStmt_Assign = edge.TypeSwitchStmt_Assign + TypeSwitchStmt_Body = edge.TypeSwitchStmt_Body + TypeSwitchStmt_Init = edge.TypeSwitchStmt_Init + UnaryExpr_X = edge.UnaryExpr_X + ValueSpec_Comment = edge.ValueSpec_Comment + ValueSpec_Doc = edge.ValueSpec_Doc + ValueSpec_Names = edge.ValueSpec_Names + ValueSpec_Type = edge.ValueSpec_Type + ValueSpec_Values = edge.ValueSpec_Values ) - -// Assert that the encoding fits in 7 bits, -// as the inspector relies on this. -// (We are currently at 104.) -var _ = [1 << 7]struct{}{}[maxKind] - -type fieldInfo struct { - nodeType reflect.Type // pointer-to-struct type of ast.Node implementation - name string - index int - fieldType reflect.Type -} - -func info[N ast.Node](fieldName string) fieldInfo { - nodePtrType := reflect.TypeFor[N]() - f, ok := nodePtrType.Elem().FieldByName(fieldName) - if !ok { - panic(fieldName) - } - return fieldInfo{nodePtrType, fieldName, f.Index[0], f.Type} -} - -var fieldInfos = [...]fieldInfo{ - Invalid: {}, - ArrayType_Elt: info[*ast.ArrayType]("Elt"), - ArrayType_Len: info[*ast.ArrayType]("Len"), - AssignStmt_Lhs: info[*ast.AssignStmt]("Lhs"), - AssignStmt_Rhs: info[*ast.AssignStmt]("Rhs"), - BinaryExpr_X: info[*ast.BinaryExpr]("X"), - BinaryExpr_Y: info[*ast.BinaryExpr]("Y"), - BlockStmt_List: info[*ast.BlockStmt]("List"), - BranchStmt_Label: info[*ast.BranchStmt]("Label"), - CallExpr_Args: info[*ast.CallExpr]("Args"), - CallExpr_Fun: info[*ast.CallExpr]("Fun"), - CaseClause_Body: info[*ast.CaseClause]("Body"), - CaseClause_List: 
info[*ast.CaseClause]("List"), - ChanType_Value: info[*ast.ChanType]("Value"), - CommClause_Body: info[*ast.CommClause]("Body"), - CommClause_Comm: info[*ast.CommClause]("Comm"), - CommentGroup_List: info[*ast.CommentGroup]("List"), - CompositeLit_Elts: info[*ast.CompositeLit]("Elts"), - CompositeLit_Type: info[*ast.CompositeLit]("Type"), - DeclStmt_Decl: info[*ast.DeclStmt]("Decl"), - DeferStmt_Call: info[*ast.DeferStmt]("Call"), - Ellipsis_Elt: info[*ast.Ellipsis]("Elt"), - ExprStmt_X: info[*ast.ExprStmt]("X"), - FieldList_List: info[*ast.FieldList]("List"), - Field_Comment: info[*ast.Field]("Comment"), - Field_Doc: info[*ast.Field]("Doc"), - Field_Names: info[*ast.Field]("Names"), - Field_Tag: info[*ast.Field]("Tag"), - Field_Type: info[*ast.Field]("Type"), - File_Decls: info[*ast.File]("Decls"), - File_Doc: info[*ast.File]("Doc"), - File_Name: info[*ast.File]("Name"), - ForStmt_Body: info[*ast.ForStmt]("Body"), - ForStmt_Cond: info[*ast.ForStmt]("Cond"), - ForStmt_Init: info[*ast.ForStmt]("Init"), - ForStmt_Post: info[*ast.ForStmt]("Post"), - FuncDecl_Body: info[*ast.FuncDecl]("Body"), - FuncDecl_Doc: info[*ast.FuncDecl]("Doc"), - FuncDecl_Name: info[*ast.FuncDecl]("Name"), - FuncDecl_Recv: info[*ast.FuncDecl]("Recv"), - FuncDecl_Type: info[*ast.FuncDecl]("Type"), - FuncLit_Body: info[*ast.FuncLit]("Body"), - FuncLit_Type: info[*ast.FuncLit]("Type"), - FuncType_Params: info[*ast.FuncType]("Params"), - FuncType_Results: info[*ast.FuncType]("Results"), - FuncType_TypeParams: info[*ast.FuncType]("TypeParams"), - GenDecl_Doc: info[*ast.GenDecl]("Doc"), - GenDecl_Specs: info[*ast.GenDecl]("Specs"), - GoStmt_Call: info[*ast.GoStmt]("Call"), - IfStmt_Body: info[*ast.IfStmt]("Body"), - IfStmt_Cond: info[*ast.IfStmt]("Cond"), - IfStmt_Else: info[*ast.IfStmt]("Else"), - IfStmt_Init: info[*ast.IfStmt]("Init"), - ImportSpec_Comment: info[*ast.ImportSpec]("Comment"), - ImportSpec_Doc: info[*ast.ImportSpec]("Doc"), - ImportSpec_Name: info[*ast.ImportSpec]("Name"), - 
ImportSpec_Path: info[*ast.ImportSpec]("Path"), - IncDecStmt_X: info[*ast.IncDecStmt]("X"), - IndexExpr_Index: info[*ast.IndexExpr]("Index"), - IndexExpr_X: info[*ast.IndexExpr]("X"), - IndexListExpr_Indices: info[*ast.IndexListExpr]("Indices"), - IndexListExpr_X: info[*ast.IndexListExpr]("X"), - InterfaceType_Methods: info[*ast.InterfaceType]("Methods"), - KeyValueExpr_Key: info[*ast.KeyValueExpr]("Key"), - KeyValueExpr_Value: info[*ast.KeyValueExpr]("Value"), - LabeledStmt_Label: info[*ast.LabeledStmt]("Label"), - LabeledStmt_Stmt: info[*ast.LabeledStmt]("Stmt"), - MapType_Key: info[*ast.MapType]("Key"), - MapType_Value: info[*ast.MapType]("Value"), - ParenExpr_X: info[*ast.ParenExpr]("X"), - RangeStmt_Body: info[*ast.RangeStmt]("Body"), - RangeStmt_Key: info[*ast.RangeStmt]("Key"), - RangeStmt_Value: info[*ast.RangeStmt]("Value"), - RangeStmt_X: info[*ast.RangeStmt]("X"), - ReturnStmt_Results: info[*ast.ReturnStmt]("Results"), - SelectStmt_Body: info[*ast.SelectStmt]("Body"), - SelectorExpr_Sel: info[*ast.SelectorExpr]("Sel"), - SelectorExpr_X: info[*ast.SelectorExpr]("X"), - SendStmt_Chan: info[*ast.SendStmt]("Chan"), - SendStmt_Value: info[*ast.SendStmt]("Value"), - SliceExpr_High: info[*ast.SliceExpr]("High"), - SliceExpr_Low: info[*ast.SliceExpr]("Low"), - SliceExpr_Max: info[*ast.SliceExpr]("Max"), - SliceExpr_X: info[*ast.SliceExpr]("X"), - StarExpr_X: info[*ast.StarExpr]("X"), - StructType_Fields: info[*ast.StructType]("Fields"), - SwitchStmt_Body: info[*ast.SwitchStmt]("Body"), - SwitchStmt_Init: info[*ast.SwitchStmt]("Init"), - SwitchStmt_Tag: info[*ast.SwitchStmt]("Tag"), - TypeAssertExpr_Type: info[*ast.TypeAssertExpr]("Type"), - TypeAssertExpr_X: info[*ast.TypeAssertExpr]("X"), - TypeSpec_Comment: info[*ast.TypeSpec]("Comment"), - TypeSpec_Doc: info[*ast.TypeSpec]("Doc"), - TypeSpec_Name: info[*ast.TypeSpec]("Name"), - TypeSpec_Type: info[*ast.TypeSpec]("Type"), - TypeSpec_TypeParams: info[*ast.TypeSpec]("TypeParams"), - TypeSwitchStmt_Assign: 
info[*ast.TypeSwitchStmt]("Assign"), - TypeSwitchStmt_Body: info[*ast.TypeSwitchStmt]("Body"), - TypeSwitchStmt_Init: info[*ast.TypeSwitchStmt]("Init"), - UnaryExpr_X: info[*ast.UnaryExpr]("X"), - ValueSpec_Comment: info[*ast.ValueSpec]("Comment"), - ValueSpec_Doc: info[*ast.ValueSpec]("Doc"), - ValueSpec_Names: info[*ast.ValueSpec]("Names"), - ValueSpec_Type: info[*ast.ValueSpec]("Type"), - ValueSpec_Values: info[*ast.ValueSpec]("Values"), -} From 3ce9106b71e3abf687b875ce87f9adaa124bcfde Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 20:21:15 -0400 Subject: [PATCH 039/196] internal/mcp: fix roots design Minor changes to the roots part of the design. Change-Id: I8125aecc5da297594f74d8d7dc5ef3d94dfa8519 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671356 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index a7a22d644d6..cb5a6c8d68b 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -575,16 +575,14 @@ type XXXParams struct { // where XXX is each type of call ``` Handlers can notify their peer about progress by calling the `NotifyProgress` -method. The notification is only sent if the peer requested it. +method. The notification is only sent if the peer requested it by providing +a progress token. ```go func (*ClientSession) NotifyProgress(context.Context, *ProgressNotification) func (*ServerSession) NotifyProgress(context.Context, *ProgressNotification) ``` -We don't support progress notifications for `Client.ListRoots`, because we expect -that operation to be instantaneous relative to network latency. 
- ### Ping / KeepAlive Both `ClientSession` and `ServerSession` expose a `Ping` method to call "ping" @@ -625,16 +623,15 @@ Clients support the MCP Roots feature out of the box, including roots-changed no Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots`: ```go -// AddRoots adds the roots to the client's list of roots. -// If the list changes, the client notifies the server. -// If a root does not begin with a valid URI schema such as "https://" or "file://", -// it is intepreted as a directory path on the local filesystem. -func (*Client) AddRoots(roots ...string) +// AddRoots adds the given roots to the client, +// replacing any with the same URIs, +// and notifies any connected servers. +func (*Client) AddRoots(roots ...Root) -// RemoveRoots removes the given roots from the client's list, and notifies -// the server if the list has changed. +// RemoveRoots removes the roots with the given URIs. +// and notifies any connected servers if the list has changed. // It is not an error to remove a nonexistent root. -func (*Client) RemoveRoots(roots ...string) +func (*Client) RemoveRoots(uris ...string) ``` Servers can call `ListRoots` to get the roots. If a server installs a From 0d237c049410f869c6f035f287d92cf3b3b26864 Mon Sep 17 00:00:00 2001 From: Nick Ripley Date: Mon, 12 May 2025 10:41:50 -0400 Subject: [PATCH 040/196] go/analysis/passes/framepointer: only stop on unconditional branches The framepointer check currently stops checking a function at the first control flow statement. This is for simplicity: the check currently reads instructions in the order they appear, without any skipping ahead or backtracking. Following branches would likely complicate the check. However, this is a bit more conservative than necessary. For conditional branches, if they aren't taken then execution continues to the next instruction. It's only for unconditional branches that we'd need more complex logic.
So, for the purposes of this check we can assume conditional branches aren't taken and continue reading through the function. This will catch some bugs that aren't currently caught by the check, e.g. the buggy assembly at the root of golang/go#69629. For golang/go#69838 Change-Id: If9b14421def4a1ac5a3edbe387cfcc0d53c0a3dc Reviewed-on: https://go-review.googlesource.com/c/tools/+/640075 Reviewed-by: Keith Randall LUCI-TryBot-Result: Go LUCI Reviewed-by: Keith Randall Reviewed-by: Dmitri Shuralyov Reviewed-by: Cherry Mui --- .../passes/framepointer/framepointer.go | 49 +++---------------- .../framepointer/testdata/src/a/asm_amd64.s | 9 +++- .../framepointer/testdata/src/a/asm_arm64.s | 13 ++++- 3 files changed, 28 insertions(+), 43 deletions(-) diff --git a/go/analysis/passes/framepointer/framepointer.go b/go/analysis/passes/framepointer/framepointer.go index ba94fd68ea4..ff9c8b4f818 100644 --- a/go/analysis/passes/framepointer/framepointer.go +++ b/go/analysis/passes/framepointer/framepointer.go @@ -28,9 +28,9 @@ var Analyzer = &analysis.Analyzer{ // Per-architecture checks for instructions. // Assume comments, leading and trailing spaces are removed. type arch struct { - isFPWrite func(string) bool - isFPRead func(string) bool - isBranch func(string) bool + isFPWrite func(string) bool + isFPRead func(string) bool + isUnconditionalBranch func(string) bool } var re = regexp.MustCompile @@ -48,8 +48,8 @@ var arches = map[string]arch{ "amd64": { isFPWrite: re(`,\s*BP$`).MatchString, // TODO: can have false positive, e.g. for TESTQ BP,BP. Seems unlikely. 
isFPRead: re(`\bBP\b`).MatchString, - isBranch: func(s string) bool { - return hasAnyPrefix(s, "J", "RET") + isUnconditionalBranch: func(s string) bool { + return hasAnyPrefix(s, "JMP", "RET") }, }, "arm64": { @@ -70,49 +70,16 @@ var arches = map[string]arch{ return false }, isFPRead: re(`\bR29\b`).MatchString, - isBranch: func(s string) bool { + isUnconditionalBranch: func(s string) bool { // Get just the instruction if i := strings.IndexFunc(s, unicode.IsSpace); i > 0 { s = s[:i] } - return arm64Branch[s] + return s == "B" || s == "JMP" || s == "RET" }, }, } -// arm64 has many control flow instructions. -// ^(B|RET) isn't sufficient or correct (e.g. BIC, BFI aren't control flow.) -// It's easier to explicitly enumerate them in a map than to write a regex. -// Borrowed from Go tree, cmd/asm/internal/arch/arm64.go -var arm64Branch = map[string]bool{ - "B": true, - "BL": true, - "BEQ": true, - "BNE": true, - "BCS": true, - "BHS": true, - "BCC": true, - "BLO": true, - "BMI": true, - "BPL": true, - "BVS": true, - "BVC": true, - "BHI": true, - "BLS": true, - "BGE": true, - "BLT": true, - "BGT": true, - "BLE": true, - "CBZ": true, - "CBZW": true, - "CBNZ": true, - "CBNZW": true, - "JMP": true, - "TBNZ": true, - "TBZ": true, - "RET": true, -} - func run(pass *analysis.Pass) (any, error) { arch, ok := arches[build.Default.GOARCH] if !ok { @@ -164,7 +131,7 @@ func run(pass *analysis.Pass) (any, error) { active = false continue } - if arch.isFPRead(line) || arch.isBranch(line) { + if arch.isFPRead(line) || arch.isUnconditionalBranch(line) { active = false continue } diff --git a/go/analysis/passes/framepointer/testdata/src/a/asm_amd64.s b/go/analysis/passes/framepointer/testdata/src/a/asm_amd64.s index a7d1b1cce7e..29d29548d7a 100644 --- a/go/analysis/passes/framepointer/testdata/src/a/asm_amd64.s +++ b/go/analysis/passes/framepointer/testdata/src/a/asm_amd64.s @@ -11,6 +11,13 @@ TEXT ·bad2(SB), 0, $0 TEXT ·bad3(SB), 0, $0 MOVQ 6(AX), BP // want `frame pointer is clobbered 
before saving` RET +TEXT ·bad4(SB), 0, $0 + CMPQ AX, BX + JEQ skip + // Assume the above conditional branch is not taken + MOVQ $0, BP // want `frame pointer is clobbered before saving` +skip: + RET TEXT ·good1(SB), 0, $0 PUSHQ BP MOVQ $0, BP // this is ok @@ -23,7 +30,7 @@ TEXT ·good2(SB), 0, $0 RET TEXT ·good3(SB), 0, $0 CMPQ AX, BX - JEQ skip + JMP skip MOVQ $0, BP // this is ok skip: RET diff --git a/go/analysis/passes/framepointer/testdata/src/a/asm_arm64.s b/go/analysis/passes/framepointer/testdata/src/a/asm_arm64.s index f2be7bdb9e9..de0626790c5 100644 --- a/go/analysis/passes/framepointer/testdata/src/a/asm_arm64.s +++ b/go/analysis/passes/framepointer/testdata/src/a/asm_arm64.s @@ -17,6 +17,17 @@ TEXT ·bad4(SB), 0, $0 TEXT ·bad5(SB), 0, $0 AND $0x1, R3, R29 // want `frame pointer is clobbered before saving` RET +TEXT ·bad6(SB), 0, $0 + CMP R1, R2 + BEQ skip + // Assume that the above conditional branch is not taken + MOVD $0, R29 // want `frame pointer is clobbered before saving` +skip: + RET +TEXT ·bad7(SB), 0, $0 + BL ·good4(SB) + AND $0x1, R3, R29 // want `frame pointer is clobbered before saving` + RET TEXT ·good1(SB), 0, $0 STPW (R29, R30), -32(RSP) MOVD $0, R29 // this is ok @@ -29,7 +40,7 @@ TEXT ·good2(SB), 0, $0 RET TEXT ·good3(SB), 0, $0 CMP R1, R2 - BEQ skip + B skip MOVD $0, R29 // this is ok skip: RET From acf038e65b084fb5d3cd11698bd18fcba7710880 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 20:35:53 -0400 Subject: [PATCH 041/196] internal/mcp/protocol: featureSet[T] What the spec calls "lists" of features (tools, prompts and so on) are actually sets. Define a common data structure we can use to implement all of them, and use it for server tools and prompts. 
Change-Id: I257d0776b87a746ff23cd49ebcc7b317f4ea65e1 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671359 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/features.go | 73 ++++++++++++++++++++++++++++++++++++++++ internal/mcp/mcp_test.go | 35 ++++++++++--------- internal/mcp/server.go | 66 ++++++++++++++++++++++-------------- 3 files changed, 132 insertions(+), 42 deletions(-) create mode 100644 internal/mcp/features.go diff --git a/internal/mcp/features.go b/internal/mcp/features.go new file mode 100644 index 00000000000..42e74c86aaf --- /dev/null +++ b/internal/mcp/features.go @@ -0,0 +1,73 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "iter" + + "golang.org/x/tools/internal/mcp/internal/util" +) + +// This file contains implementations that are common to all features. +// A feature is an item provided to a peer. In the 2025-03-26 spec, +// the features are prompt, tool, resource and root. + +// A featureSet is a collection of features of type T. +// Every feature has a unique ID, and the spec never mentions +// an ordering for the List calls, so what it calls a "list" is actually a set. +type featureSet[T any] struct { + uniqueID func(T) string + features map[string]T +} + +// newFeatureSet creates a new featureSet for features of type T. +// The argument function should return the unique ID for a single feature. +func newFeatureSet[T any](uniqueIDFunc func(T) string) *featureSet[T] { + return &featureSet[T]{ + uniqueID: uniqueIDFunc, + features: make(map[string]T), + } +} + +// add adds each feature to the set if it is not present, +// or replaces an existing feature. 
+func (s *featureSet[T]) add(fs ...T) { + for _, f := range fs { + s.features[s.uniqueID(f)] = f + } +} + +// remove removes all features with the given uids from the set if present, +// and returns whether any were removed. +// It is not an error to remove a nonexistent feature. +func (s *featureSet[T]) remove(uids ...string) bool { + changed := false + for _, uid := range uids { + if _, ok := s.features[uid]; ok { + changed = true + delete(s.features, uid) + } + } + return changed +} + +// get returns the feature with the given uid. +// If there is none, it returns zero, false. +func (s *featureSet[T]) get(uid string) (T, bool) { + t, ok := s.features[uid] + return t, ok +} + +// all returns an iterator over of all the features in the set +// sorted by unique ID. +func (s *featureSet[T]) all() iter.Seq[T] { + return func(yield func(T) bool) { + for _, f := range util.Sorted(s.features) { + if !yield(f) { + return + } + } + } +} diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 2797519c28c..be26ee79b8f 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -131,25 +131,28 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Errorf("tools/list failed: %v", err) } - wantTools := []protocol.Tool{{ - Name: "greet", - Description: "say hi", - InputSchema: &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string"}, + wantTools := []protocol.Tool{ + { + Name: "fail", + Description: "just fail", + InputSchema: &jsonschema.Schema{ + Type: "object", + AdditionalProperties: falseSchema, }, - AdditionalProperties: falseSchema, }, - }, { - Name: "fail", - Description: "just fail", - InputSchema: &jsonschema.Schema{ - Type: "object", - AdditionalProperties: falseSchema, + { + Name: "greet", + Description: "say hi", + InputSchema: &jsonschema.Schema{ + Type: "object", + Required: []string{"Name"}, + Properties: map[string]*jsonschema.Schema{ + "Name": {Type: 
"string"}, + }, + AdditionalProperties: falseSchema, + }, }, - }} + } if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 4f9e196e03f..12b7b4bee0f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -21,13 +21,14 @@ import ( // Servers expose server-side MCP features, which can serve one or more MCP // sessions by using [Server.Start] or [Server.Run]. type Server struct { + // fixed at creation name string version string opts ServerOptions mu sync.Mutex - prompts []*Prompt - tools []*Tool + prompts *featureSet[*Prompt] + tools *featureSet[*Tool] conns []*ServerConnection } @@ -51,16 +52,29 @@ func NewServer(name, version string, opts *ServerOptions) *Server { name: name, version: version, opts: *opts, + prompts: newFeatureSet(func(p *Prompt) string { return p.Definition.Name }), + tools: newFeatureSet(func(t *Tool) string { return t.Definition.Name }), } } // AddPrompts adds the given prompts to the server. -// -// TODO(rfindley): notify connected clients of any changes. func (s *Server) AddPrompts(prompts ...*Prompt) { s.mu.Lock() defer s.mu.Unlock() - s.prompts = append(s.prompts, prompts...) + s.prompts.add(prompts...) + // Assume there was a change, since add replaces existing prompts. + // (It's possible a prompt was replaced with an identical one, but not worth checking.) + // TODO(rfindley): notify connected clients +} + +// RemovePrompts removes if the prompts with the given names. +// It is not an error to remove a nonexistent prompt. +func (s *Server) RemovePrompts(names ...string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.prompts.remove(names...) { + // TODO: notify + } } // AddTools adds the given tools to the server. 
@@ -69,7 +83,20 @@ func (s *Server) AddPrompts(prompts ...*Prompt) { func (s *Server) AddTools(tools ...*Tool) { s.mu.Lock() defer s.mu.Unlock() - s.tools = append(s.tools, tools...) + s.tools.add(tools...) + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO(rfindley): notify connected clients +} + +// RemoveTools removes if the tools with the given names. +// It is not an error to remove a nonexistent tool. +func (s *Server) RemoveTools(names ...string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.tools.remove(names...) { + // TODO: notify + } } // Clients returns an iterator that yields the current set of client @@ -84,9 +111,8 @@ func (s *Server) Clients() iter.Seq[*ServerConnection] { func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *protocol.ListPromptsParams) (*protocol.ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() - res := new(protocol.ListPromptsResult) - for _, p := range s.prompts { + for p := range s.prompts.all() { res.Prompts = append(res.Prompts, p.Definition) } return res, nil @@ -94,15 +120,10 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *pro func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *protocol.GetPromptParams) (*protocol.GetPromptResult, error) { s.mu.Lock() - var prompt *Prompt - if i := slices.IndexFunc(s.prompts, func(t *Prompt) bool { - return t.Definition.Name == params.Name - }); i >= 0 { - prompt = s.prompts[i] - } + prompt, ok := s.prompts.get(params.Name) s.mu.Unlock() - - if prompt == nil { + if !ok { + // TODO: surface the error code over the wire, instead of flattening it into the string. 
return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } return prompt.Handler(ctx, cc, params.Arguments) @@ -111,9 +132,8 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *pr func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *protocol.ListToolsParams) (*protocol.ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() - res := new(protocol.ListToolsResult) - for _, t := range s.tools { + for t := range s.tools.all() { res.Tools = append(res.Tools, t.Definition) } return res, nil @@ -121,15 +141,9 @@ func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *proto func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) { s.mu.Lock() - var tool *Tool - if i := slices.IndexFunc(s.tools, func(t *Tool) bool { - return t.Definition.Name == params.Name - }); i >= 0 { - tool = s.tools[i] - } + tool, ok := s.tools.get(params.Name) s.mu.Unlock() - - if tool == nil { + if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) } return tool.Handler(ctx, cc, params.Arguments) From 7b959ffb83ad0c9c6c15c669fd708b8be6b95a5d Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 12 May 2025 13:37:00 -0400 Subject: [PATCH 042/196] go/ast/inspector: improve doc comments Updates golang/go#70859 Change-Id: I0b294c3529717d0224d3f33e1b854ba98b53a571 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672015 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- go/ast/astutil/rewrite.go | 4 ++++ go/ast/inspector/cursor.go | 26 +++++++++++++++++++------- go/ast/inspector/inspector.go | 17 +++++++++++++---- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/go/ast/astutil/rewrite.go b/go/ast/astutil/rewrite.go index 5c8dbbb7a35..4ad0549304c 100644 --- a/go/ast/astutil/rewrite.go +++ b/go/ast/astutil/rewrite.go @@ -67,6 +67,10 @@ var abort = 
new(int) // singleton, to signal termination of Apply // // The methods Replace, Delete, InsertBefore, and InsertAfter // can be used to change the AST without disrupting Apply. +// +// This type is not to be confused with [inspector.Cursor] from +// package [golang.org/x/tools/go/ast/inspector], which provides +// stateless navigation of immutable syntax trees. type Cursor struct { parent ast.Node name string diff --git a/go/ast/inspector/cursor.go b/go/ast/inspector/cursor.go index cd10afa5889..bec9e4decac 100644 --- a/go/ast/inspector/cursor.go +++ b/go/ast/inspector/cursor.go @@ -5,7 +5,6 @@ package inspector // TODO(adonovan): -// - review package documentation // - apply-all //go:fix inline import ( @@ -22,7 +21,20 @@ import ( // // Two Cursors compare equal if they represent the same node. // -// Call [Inspector.Root] to obtain a valid cursor. +// Call [Inspector.Root] to obtain a valid cursor for the virtual root +// node of the traversal. +// +// Use the following methods to navigate efficiently around the tree: +// - for ancestors, use [Cursor.Parent] and [Cursor.Enclosing]; +// - for children, use [Cursor.Child], [Cursor.Children], +// [Cursor.FirstChild], and [Cursor.LastChild]; +// - for siblings, use [Cursor.PrevSibling] and [Cursor.NextSibling]; +// - for descendants, use [Cursor.FindByPos], [Cursor.FindNode], +// [Cursor.Inspect], and [Cursor.Preorder]. +// +// Use the [Cursor.ChildAt] and [Cursor.ParentEdge] methods for +// information about the edges in a tree: which field (and slice +// element) of the parent node holds the child. type Cursor struct { in *Inspector index int32 // index of push node; -1 for virtual root node @@ -369,11 +381,11 @@ func (c Cursor) LastChild() (Cursor, bool) { // of expressions and statements. 
Other nodes that have "uncontained" // list fields include: // -// - [ast.ValueSpec] (Names, Values) -// - [ast.CompositeLit] (Type, Elts) -// - [ast.IndexListExpr] (X, Indices) -// - [ast.CallExpr] (Fun, Args) -// - [ast.AssignStmt] (Lhs, Rhs) +// - [ast.ValueSpec] (Names, Values) +// - [ast.CompositeLit] (Type, Elts) +// - [ast.IndexListExpr] (X, Indices) +// - [ast.CallExpr] (Fun, Args) +// - [ast.AssignStmt] (Lhs, Rhs) // // So, do not assume that the previous sibling of an ast.Stmt is also // an ast.Stmt, or if it is, that they are executed sequentially, diff --git a/go/ast/inspector/inspector.go b/go/ast/inspector/inspector.go index b07318ac4c5..656302e2494 100644 --- a/go/ast/inspector/inspector.go +++ b/go/ast/inspector/inspector.go @@ -13,10 +13,19 @@ // This representation is sometimes called a "balanced parenthesis tree." // // Experiments suggest the inspector's traversals are about 2.5x faster -// than ast.Inspect, but it may take around 5 traversals for this +// than [ast.Inspect], but it may take around 5 traversals for this // benefit to amortize the inspector's construction cost. // If efficiency is the primary concern, do not use Inspector for // one-off traversals. +// +// The [Cursor] type provides a more flexible API for efficient +// navigation of syntax trees in all four "cardinal directions". For +// example, traversals may be nested, so you can find each node of +// type A and then search within it for nodes of type B. Or you can +// traverse from a node to its immediate neighbors: its parent, its +// previous and next sibling, or its first and last child. We +// recommend using methods of Cursor in preference to Inspector where +// possible. package inspector // There are four orthogonal features in a traversal: @@ -82,7 +91,7 @@ type event struct { // depth-first order. It calls f(n) for each node n before it visits // n's children. // -// The complete traversal sequence is determined by ast.Inspect. 
+// The complete traversal sequence is determined by [ast.Inspect]. // The types argument, if non-empty, enables type-based filtering of // events. The function f is called only for nodes whose type // matches an element of the types slice. @@ -130,7 +139,7 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // of the non-nil children of the node, followed by a call of // f(n, false). // -// The complete traversal sequence is determined by ast.Inspect. +// The complete traversal sequence is determined by [ast.Inspect]. // The types argument, if non-empty, enables type-based filtering of // events. The function f if is called only for nodes whose type // matches an element of the types slice. @@ -249,7 +258,7 @@ type visitor struct { type item struct { index int32 // index of current node's push event parentIndex int32 // index of parent node's push event - typAccum uint64 // accumulated type bits of current node's descendents + typAccum uint64 // accumulated type bits of current node's descendants edgeKindAndIndex int32 // edge.Kind and index, bit packed } From 2835a17831c928f418855ad7466a0a5636363a20 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 11 May 2025 11:59:19 -0400 Subject: [PATCH 043/196] internal/mcp: design.md: adjust resource design Add handlers for reading resources. Change-Id: Ibe0553a02c6fa2b5159cc406c7c4481984b1a460 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671357 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 44 +++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index cb5a6c8d68b..832b79c037e 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -852,24 +852,58 @@ handler to a Go function using reflection to derive its arguments. 
We provide ### Resources and resource templates -Servers have Add and Remove methods for resources and resource templates: +To add a resource or resource template to a server, users call the `AddResource` and +`AddResourceTemplate` methods, passing the resource or template and a function for reading it: +```go +type ReadResourceHandler func(context.Context, *ServerSession, *Resource, *ReadResourceParams) (*ReadResourceResult, error) + +func (*Server) AddResource(*Resource, ReadResourceHandler) +func (*Server) AddResourceTemplate(*ResourceTemplate, ReadResourceHandler) +``` +The `Resource` is passed to the reader function even though it is redundant (the function could have closed over it) +so a single handler can support multiple resources. +If the incoming resource matches a template, a `Resource` argument is constructed +from the fields in the `ResourceTemplate`. +The `ServerSession` argument is there so the reader can observe the client's roots. +To read files from the local filesystem, we recommend using `FileReadResourceHandler` to construct a handler: ```go -func (*Server) AddResources(resources ...*Resource) -func (*Server) RemoveResources(names ...string) -func (*Server) AddResourceTemplates(templates...*ResourceTemplate) +// FileReadResourceHandler returns a ReadResourceHandler that reads paths using dir as a root directory. +// It protects against path traversal attacks. +// It will not read any file that is not in the root set of the client requesting the resource. +func (*Server) FileReadResourceHandler(dir string) ReadResourceHandler +``` +It guards against [path traversal attacks](https://go.dev/blog/osroot) +and observes the client's roots. +Here is an example: +```go +// Safely read "/public/puppies.txt". +s.AddResource( + &mcp.Resource{URI: "file:///puppies.txt"}, + s.FileReadResourceHandler("/public")) +``` + +There are also server methods to remove resources and resource templates. 
+```go +func (*Server) RemoveResources(uris ...string) func (*Server) RemoveResourceTemplates(names ...string) ``` +Resource templates don't have unique identifiers, so removing a name will remove all +resource templates with that name. Clients call `ListResources` to list the available resources, `ReadResource` to read one of them, and `ListResourceTemplates` to list the templates: ```go func (*ClientSession) ListResources(context.Context, *ListResourcesParams) (*ListResourcesResult, error) -func (*ClientSession) ReadResource(context.Context, *ReadResourceParams) (*ReadResourceResult, error) func (*ClientSession) ListResourceTemplates(context.Context, *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) +func (*ClientSession) ReadResource(context.Context, *ReadResourceParams) (*ReadResourceResult, error) ``` +`ReadResource` checks the incoming URI against the server's list of +resources and resource templates to make sure it matches one of them, +then returns the result of calling the associated reader function. + ### ListChanged notifications From bbef3e4296a1595c5c09b90ee8c80937dfcd3201 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 9 May 2025 22:04:50 -0400 Subject: [PATCH 044/196] internal/mcp: implement roots Add support for roots. 
Change-Id: Ia50abc88f0047238272d698f30ce615b1a8fd486 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671360 LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley --- internal/mcp/client.go | 62 ++++++++++++++++++++++++------- internal/mcp/mcp_test.go | 12 +++++- internal/mcp/protocol/generate.go | 13 ++++++- internal/mcp/protocol/protocol.go | 53 +++++++++++++++++++++++--- internal/mcp/root.go | 5 +++ internal/mcp/server.go | 16 +++++--- 6 files changed, 133 insertions(+), 28 deletions(-) create mode 100644 internal/mcp/root.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 8af1c7ea8f6..6592cc8eb6d 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "slices" "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" @@ -24,6 +25,7 @@ type Client struct { opts ClientOptions mu sync.Mutex conn *jsonrpc2.Connection + roots *featureSet[protocol.Root] initializeResult *protocol.InitializeResult } @@ -37,6 +39,7 @@ func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { name: name, version: version, transport: t, + roots: newFeatureSet(func(r protocol.Root) string { return r.URI }), } if opts != nil { c.opts = *opts @@ -106,13 +109,47 @@ func (c *Client) Wait() error { return c.conn.Wait() } -func (*Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { +// AddRoots adds the given roots to the client, +// replacing any with the same URIs, +// and notifies any connected servers. +// TODO: notification +func (c *Client) AddRoots(roots ...protocol.Root) { + c.mu.Lock() + defer c.mu.Unlock() + c.roots.add(roots...) +} + +// RemoveRoots removes the roots with the given URIs, +// and notifies any connected servers if the list has changed. +// It is not an error to remove a nonexistent root. 
+// TODO: notification +func (c *Client) RemoveRoots(uris ...string) { + c.mu.Lock() + defer c.mu.Unlock() + c.roots.remove(uris...) +} + +func (c *Client) listRoots(_ context.Context, _ *protocol.ListRootsParams) (*protocol.ListRootsResult, error) { + c.mu.Lock() + defer c.mu.Unlock() + return &protocol.ListRootsResult{ + Roots: slices.Collect(c.roots.all()), + }, nil +} + +func (c *Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { + // TODO: when we switch to ClientSessions, use a copy of the server's dispatch function, or + // maybe just add another type parameter. + // // No need to check that the connection is initialized, since we initialize // it in Connect. switch req.Method { case "ping": // The spec says that 'ping' expects an empty object result. return struct{}{}, nil + case "roots/list": + // ListRootsParams happens to be unused. + return c.listRoots(ctx, nil) } return nil, jsonrpc2.ErrNotHandled } @@ -162,10 +199,6 @@ func (c *Client) ListTools(ctx context.Context) ([]protocol.Tool, error) { } // CallTool calls the tool with the given name and arguments. -// -// TODO(jba): make the following true: -// If the provided arguments do not conform to the schema for the given tool, -// the call fails. 
func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) (_ *protocol.CallToolResult, err error) { defer func() { if err != nil { @@ -180,14 +213,17 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) } argsJSON[name] = argJSON } - var ( - params = &protocol.CallToolParams{ - Name: name, - Arguments: argsJSON, - } - result protocol.CallToolResult - ) - if err := call(ctx, c.conn, "tools/call", params, &result); err != nil { + + params := &protocol.CallToolParams{ + Name: name, + Arguments: argsJSON, + } + return standardCall[protocol.CallToolResult](ctx, c.conn, "tools/call", params) +} + +func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { + var result TRes + if err := call(ctx, conn, method, params, &result); err != nil { return nil, err } return &result, nil diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index be26ee79b8f..aea6081473e 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -63,7 +63,7 @@ func TestEndToEnd(t *testing.T) { ) // Connect the server. - cc, err := s.Connect(ctx, st, nil) + sc, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -75,13 +75,14 @@ func TestEndToEnd(t *testing.T) { var clientWG sync.WaitGroup clientWG.Add(1) go func() { - if err := cc.Wait(); err != nil { + if err := sc.Wait(); err != nil { t.Errorf("server failed: %v", err) } clientWG.Done() }() c := NewClient("testClient", "v1.0.0", ct, nil) + c.AddRoots(protocol.Root{URI: "file:///root"}) // Connect the client. 
if err := c.Start(ctx); err != nil { @@ -182,6 +183,13 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } + rootRes, err := sc.ListRoots(ctx, &protocol.ListRootsParams{}) + gotRoots := rootRes.Roots + wantRoots := slices.Collect(c.roots.all()) + if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { + t.Errorf("roots/list mismatch (-want +got):\n%s", diff) + } + // Disconnect. c.Close() clientWG.Wait() diff --git a/internal/mcp/protocol/generate.go b/internal/mcp/protocol/generate.go index 30bc81d978b..273500881e5 100644 --- a/internal/mcp/protocol/generate.go +++ b/internal/mcp/protocol/generate.go @@ -75,6 +75,10 @@ var declarations = config{ Fields: config{"Params": {Name: "ListPromptsParams"}}, }, "ListPromptsResult": {Name: "ListPromptsResult"}, + "ListRootsRequest": { + Fields: config{"Params": {Name: "ListRootsParams"}}, + }, + "ListRootsResult": {Name: "ListRootsResult"}, "ListToolsRequest": { Fields: config{"Params": {Name: "ListToolsParams"}}, }, @@ -82,8 +86,11 @@ var declarations = config{ "Prompt": {Name: "Prompt"}, "PromptMessage": {Name: "PromptMessage"}, "PromptArgument": {Name: "PromptArgument"}, + "ProgressToken": {Substitute: "any"}, // null|number|string "RequestId": {Substitute: "any"}, // null|number|string "Role": {Name: "Role"}, + "Root": {Name: "Root"}, + "ServerCapabilities": { Name: "ServerCapabilities", Fields: config{ @@ -243,7 +250,11 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma // For types that explicitly allow additional properties, we can either // unmarshal them into a map[string]any, or delay unmarshalling with // json.RawMessage. For now, use json.RawMessage as it defers the choice. - if def.Type == "object" && canHaveAdditionalProperties(def) { + // + // TODO(jba): further refine this classification of object schemas. 
+ // For example, the typescript "object" type, which should map to a Go "any", + // is represented in schema.json by `{type: object, properties: {}, additionalProperties: true}`. + if def.Type == "object" && canHaveAdditionalProperties(def) && def.Properties == nil { w.Write([]byte("map[string]")) return writeType(w, nil, def.AdditionalProperties, named) } diff --git a/internal/mcp/protocol/protocol.go b/internal/mcp/protocol/protocol.go index 399408238a4..fbc7b256fae 100644 --- a/internal/mcp/protocol/protocol.go +++ b/internal/mcp/protocol/protocol.go @@ -70,14 +70,16 @@ type CancelledParams struct { // additional capabilities. type ClientCapabilities struct { // Experimental, non-standard capabilities that the client supports. - Experimental map[string]map[string]json.RawMessage `json:"experimental,omitempty"` + Experimental map[string]struct { + } `json:"experimental,omitempty"` // Present if the client supports listing roots. Roots *struct { // Whether the client supports notifications for changes to the roots list. ListChanged bool `json:"listChanged,omitempty"` } `json:"roots,omitempty"` // Present if the client supports sampling from an LLM. - Sampling map[string]json.RawMessage `json:"sampling,omitempty"` + Sampling struct { + } `json:"sampling,omitempty"` } type GetPromptParams struct { @@ -131,7 +133,11 @@ type InitializeResult struct { ServerInfo Implementation `json:"serverInfo"` } -type InitializedParams map[string]json.RawMessage +type InitializedParams struct { + // This parameter name is reserved by MCP to allow clients and servers to attach + // additional metadata to their notifications. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` +} type ListPromptsParams struct { // An opaque token representing the current pagination position. 
If provided, @@ -150,6 +156,26 @@ type ListPromptsResult struct { Prompts []Prompt `json:"prompts"` } +type ListRootsParams struct { + Meta *struct { + // If specified, the caller is requesting out-of-band progress notifications for + // this request (as represented by notifications/progress). The value of this + // parameter is an opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these notifications. + ProgressToken *any `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` +} + +// The client's response to a roots/list request from the server. This result +// contains an array of Root objects, each representing a root directory or file +// that the server can operate on. +type ListRootsResult struct { + // This result property is reserved by the protocol to allow clients and servers + // to attach additional metadata to their responses. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + Roots []Root `json:"roots"` +} + type ListToolsParams struct { // An opaque token representing the current pagination position. If provided, // the server should return results starting after this cursor. @@ -213,16 +239,31 @@ type ResourceCapabilities struct { // The sender or recipient of messages and data in a conversation. type Role string +// Represents a root directory or file that the server can operate on. +type Root struct { + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` + // The URI identifying the root. This *must* start with file:// for now. This + // restriction may be relaxed in future versions of the protocol to allow other + // URI schemes. + URI string `json:"uri"` +} + // Capabilities that a server may support. 
Known capabilities are defined here, // in this schema, but this is not a closed set: any server can define its own, // additional capabilities. type ServerCapabilities struct { // Present if the server supports argument autocompletion suggestions. - Completions map[string]json.RawMessage `json:"completions,omitempty"` + Completions struct { + } `json:"completions,omitempty"` // Experimental, non-standard capabilities that the server supports. - Experimental map[string]map[string]json.RawMessage `json:"experimental,omitempty"` + Experimental map[string]struct { + } `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. - Logging map[string]json.RawMessage `json:"logging,omitempty"` + Logging struct { + } `json:"logging,omitempty"` // Present if the server offers any prompt templates. Prompts *PromptCapabilities `json:"prompts,omitempty"` // Present if the server offers any resources to read. diff --git a/internal/mcp/root.go b/internal/mcp/root.go new file mode 100644 index 00000000000..2eccf60ad41 --- /dev/null +++ b/internal/mcp/root.go @@ -0,0 +1,5 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 12b7b4bee0f..0a75411959e 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -57,7 +57,8 @@ func NewServer(name, version string, opts *ServerOptions) *Server { } } -// AddPrompts adds the given prompts to the server. +// AddPrompts adds the given prompts to the server, +// replacing any with the same names. func (s *Server) AddPrompts(prompts ...*Prompt) { s.mu.Lock() defer s.mu.Unlock() @@ -67,7 +68,7 @@ func (s *Server) AddPrompts(prompts ...*Prompt) { // TODO(rfindley): notify connected clients } -// RemovePrompts removes if the prompts with the given names. 
+// RemovePrompts removes the prompts with the given names. // It is not an error to remove a nonexistent prompt. func (s *Server) RemovePrompts(names ...string) { s.mu.Lock() @@ -77,9 +78,8 @@ func (s *Server) RemovePrompts(names ...string) { } } -// AddTools adds the given tools to the server. -// -// TODO(rfindley): notify connected clients of any changes. +// AddTools adds the given tools to the server, +// replacing any with the same names. func (s *Server) AddTools(tools ...*Tool) { s.mu.Lock() defer s.mu.Unlock() @@ -89,7 +89,7 @@ func (s *Server) AddTools(tools ...*Tool) { // TODO(rfindley): notify connected clients } -// RemoveTools removes if the tools with the given names. +// RemoveTools removes the tools with the given names. // It is not an error to remove a nonexistent tool. func (s *Server) RemoveTools(names ...string) { s.mu.Lock() @@ -210,6 +210,10 @@ func (cc *ServerConnection) Ping(ctx context.Context) error { return call(ctx, cc.conn, "ping", nil, nil) } +func (cc *ServerConnection) ListRoots(ctx context.Context, params *protocol.ListRootsParams) (*protocol.ListRootsResult, error) { + return standardCall[protocol.ListRootsResult](ctx, cc.conn, "roots/list", params) +} + func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { cc.mu.Lock() initialized := cc.initialized From 3818858976f9a520aa78f69a89a04455d01b42a1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 10 May 2025 08:15:31 -0400 Subject: [PATCH 045/196] internal/mcp/protocol: make type name more convenient You can now leave the name empty in a config to use the original name, and write "-" to discard a top-level declaration. Also, fix a bug in dealing with initialisms.
Change-Id: I4d5d0acb1200fff6500644457ec34c5026d63f9f Reviewed-on: https://go-review.googlesource.com/c/tools/+/671361 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla --- internal/mcp/protocol/generate.go | 92 +++++++++++++++++++------------ 1 file changed, 56 insertions(+), 36 deletions(-) diff --git a/internal/mcp/protocol/generate.go b/internal/mcp/protocol/generate.go index 273500881e5..ff5f2465627 100644 --- a/internal/mcp/protocol/generate.go +++ b/internal/mcp/protocol/generate.go @@ -46,53 +46,61 @@ type config map[string]*typeConfig // declarations configures the set of declarations to write. // -// Top level declarations are only created if they are configured with a -// non-empty Name. Otherwise, they are discarded, though their fields may be +// Top level declarations are created unless configured with Name=="-", +// in which case they are discarded, though their fields may be // extracted to types if they have a nested field configuration. +// If Name == "", the map key is used as the type name. 
var declarations = config{ - "Annotations": {Name: "Annotations"}, + "Annotations": {}, "CallToolRequest": { + Name: "-", Fields: config{"Params": {Name: "CallToolParams"}}, }, - "CallToolResult": {Name: "CallToolResult"}, + "CallToolResult": {}, "CancelledNotification": { + Name: "-", Fields: config{"Params": {Name: "CancelledParams"}}, }, - "ClientCapabilities": {Name: "ClientCapabilities"}, + "ClientCapabilities": {}, "GetPromptRequest": { + Name: "-", Fields: config{"Params": {Name: "GetPromptParams"}}, }, - "GetPromptResult": {Name: "GetPromptResult"}, - "Implementation": {Name: "Implementation"}, + "GetPromptResult": {}, + "Implementation": {}, "InitializeRequest": { + Name: "-", Fields: config{"Params": {Name: "InitializeParams"}}, }, - "InitializeResult": {Name: "InitializeResult"}, + "InitializeResult": {}, "InitializedNotification": { + Name: "-", Fields: config{"Params": {Name: "InitializedParams"}}, }, "ListPromptsRequest": { + Name: "-", Fields: config{"Params": {Name: "ListPromptsParams"}}, }, - "ListPromptsResult": {Name: "ListPromptsResult"}, + "ListPromptsResult": {}, "ListRootsRequest": { + Name: "-", Fields: config{"Params": {Name: "ListRootsParams"}}, }, - "ListRootsResult": {Name: "ListRootsResult"}, + "ListRootsResult": {}, "ListToolsRequest": { + Name: "-", Fields: config{"Params": {Name: "ListToolsParams"}}, }, - "ListToolsResult": {Name: "ListToolsResult"}, - "Prompt": {Name: "Prompt"}, - "PromptMessage": {Name: "PromptMessage"}, - "PromptArgument": {Name: "PromptArgument"}, - "ProgressToken": {Substitute: "any"}, // null|number|string - "RequestId": {Substitute: "any"}, // null|number|string - "Role": {Name: "Role"}, - "Root": {Name: "Root"}, + "ListToolsResult": {}, + "Prompt": {}, + "PromptMessage": {}, + "PromptArgument": {}, + "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string + "RequestId": {Name: "-", Substitute: "any"}, // null|number|string + "Role": {}, + "Root": {}, "ServerCapabilities": { - Name: 
"ServerCapabilities", Fields: config{ "Prompts": {Name: "PromptCapabilities"}, "Resources": {Name: "ResourceCapabilities"}, @@ -100,10 +108,9 @@ var declarations = config{ }, }, "Tool": { - Name: "Tool", Fields: config{"InputSchema": {Substitute: "*jsonschema.Schema"}}, }, - "ToolAnnotations": {Name: "ToolAnnotations"}, + "ToolAnnotations": {}, } func main() { @@ -128,7 +135,7 @@ func main() { if config == nil { continue } - if err := writeDecl(*config, def, named); err != nil { + if err := writeDecl(name, *config, def, named); err != nil { log.Fatal(err) } } @@ -199,19 +206,22 @@ func loadSchema(schemaFile string) (data []byte, err error) { return data, nil } -func writeDecl(config typeConfig, def *jsonschema.Schema, named map[string]*bytes.Buffer) error { +func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, named map[string]*bytes.Buffer) error { var w io.Writer = io.Discard - if name := config.Name; name != "" { - if _, ok := named[name]; ok { + if typeName := config.Name; typeName != "-" { + if typeName == "" { + typeName = configName + } + if _, ok := named[typeName]; ok { return nil } buf := new(bytes.Buffer) w = buf - named[name] = buf + named[typeName] = buf if def.Description != "" { fmt.Fprintf(buf, "%s\n", toComment(def.Description)) } - fmt.Fprintf(buf, "type %s ", name) + fmt.Fprintf(buf, "type %s ", typeName) } if err := writeType(w, &config, def, named); err != nil { return err // Better error here? @@ -234,12 +244,12 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma // definition is missing, *but only if w is not io.Discard*. That's not a // great API: see if we can do something more explicit than io.Discard. 
if cfg, ok := declarations[name]; ok { - if cfg.Name == "" && cfg.Substitute == "" { + if cfg.Name == "-" && cfg.Substitute == "" { panic(fmt.Sprintf("referenced type %q cannot be referred to (no name or substitution)", name)) } if cfg.Substitute != "" { name = cfg.Substitute - } else { + } else if cfg.Name != "" { name = cfg.Name } } @@ -311,14 +321,18 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma if r.Substitute != "" { fmt.Fprintf(w, r.Substitute) } else { - assert(r.Name != "", "missing ExtractTo") - if err := writeDecl(*r, fieldDef, named); err != nil { + assert(r.Name != "-", "missing ExtractTo") + typename := export + if r.Name != "" { + typename = r.Name + } + if err := writeDecl(typename, *r, fieldDef, named); err != nil { return err } if needPointer { fmt.Fprintf(w, "*") } - fmt.Fprintf(w, r.Name) + fmt.Fprintf(w, typename) } } else { if needPointer { @@ -410,7 +424,12 @@ func exportName(s string) string { // at once, because the replacement will change the indices.) for { if loc := re.FindStringIndex(s); loc != nil { - s = s[:loc[0]] + replacement + s[loc[1]:] + // Don't replace the rune after the initialism, if any. + end := loc[1] + if end < len(s) { + end-- + } + s = s[:loc[0]] + replacement + s[end:] } else { break } @@ -421,9 +440,10 @@ func exportName(s string) string { // Map from initialism to the regexp that matches it. var initialisms = map[string]*regexp.Regexp{ - "Id": nil, - "Url": nil, - "Uri": nil, + "Id": nil, + "Url": nil, + "Uri": nil, + "Mime": nil, } func init() { From 6dfeba5cf50918ab108a7bedee3dd93d386f8e68 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 10 May 2025 13:47:22 -0400 Subject: [PATCH 046/196] internal/mcp: adjust content types - Change the names of some content-related types to better match the spec. Most notably, rename Resource to EmbeddedResource, because the former has a different meaning in the spec. 
- Use []byte instead of string where the spec says "base64-encoded data." - Unexport EmbeddedResource.ToWire. Change-Id: I9f65efaadadda1b3abc0e58c64a743af1a72852c Reviewed-on: https://go-review.googlesource.com/c/tools/+/671362 Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam Commit-Queue: Jonathan Amsterdam --- internal/mcp/content.go | 66 +++++++++++++++++--------------- internal/mcp/content_test.go | 28 ++++++-------- internal/mcp/protocol/content.go | 40 +++++++++++-------- 3 files changed, 72 insertions(+), 62 deletions(-) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 7a3687dd284..bf318e32ee1 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -15,6 +15,7 @@ import ( // // ToWire converts content to its jsonrpc2 wire format. type Content interface { + // TODO: unexport this, and move the tests that use it to this package. ToWire() protocol.Content } @@ -29,64 +30,69 @@ func (c TextContent) ToWire() protocol.Content { // ImageContent contains base64-encoded image data. type ImageContent struct { - Data string - MimeType string + Data []byte // base64-encoded + MIMEType string } func (c ImageContent) ToWire() protocol.Content { - return protocol.Content{Type: "image", MIMEType: c.MimeType, Data: c.Data} + return protocol.Content{Type: "image", MIMEType: c.MIMEType, Data: c.Data} } // AudioContent contains base64-encoded audio data. type AudioContent struct { - Data string - MimeType string + Data []byte + MIMEType string } func (c AudioContent) ToWire() protocol.Content { - return protocol.Content{Type: "audio", MIMEType: c.MimeType, Data: c.Data} + return protocol.Content{Type: "audio", MIMEType: c.MIMEType, Data: c.Data} } // ResourceContent contains embedded resources.
type ResourceContent struct { - Resource Resource + Resource EmbeddedResource } func (r ResourceContent) ToWire() protocol.Content { - res := r.Resource.ToWire() + res := r.Resource.toWire() return protocol.Content{Type: "resource", Resource: &res} } -type Resource interface { - ToWire() protocol.Resource +type EmbeddedResource interface { + toWire() protocol.ResourceContents } -type TextResource struct { +// The {Text,Blob}ResourceContents types match the protocol definitions, +// but we represent both as a single type on the wire. + +// A TextResourceContents is the contents of a text resource. +type TextResourceContents struct { URI string - MimeType string + MIMEType string Text string } -func (r TextResource) ToWire() protocol.Resource { - return protocol.Resource{ +func (r TextResourceContents) toWire() protocol.ResourceContents { + return protocol.ResourceContents{ URI: r.URI, - MIMEType: r.MimeType, + MIMEType: r.MIMEType, Text: r.Text, + // Blob is nil, indicating this is a TextResourceContents. } } -type BlobResource struct { +// A BlobResourceContents is the contents of a blob resource. 
+type BlobResourceContents struct { URI string - MimeType string - Blob string + MIMEType string + Blob []byte } -func (r BlobResource) ToWire() protocol.Resource { - blob := r.Blob - return protocol.Resource{ +func (r BlobResourceContents) toWire() protocol.ResourceContents { + return protocol.ResourceContents{ URI: r.URI, - MIMEType: r.MimeType, - Blob: &blob, + MIMEType: r.MIMEType, + Blob: r.Blob, } } @@ -97,22 +103,22 @@ func ContentFromWireContent(c protocol.Content) Content { case "text": return TextContent{Text: c.Text} case "image": - return ImageContent{Data: c.Data, MimeType: c.MIMEType} + return ImageContent{Data: c.Data, MIMEType: c.MIMEType} case "audio": - return AudioContent{Data: c.Data, MimeType: c.MIMEType} + return AudioContent{Data: c.Data, MIMEType: c.MIMEType} case "resource": r := ResourceContent{} if c.Resource != nil { if c.Resource.Blob != nil { - r.Resource = BlobResource{ + r.Resource = BlobResourceContents{ URI: c.Resource.URI, - MimeType: c.Resource.MIMEType, - Blob: *c.Resource.Blob, + MIMEType: c.Resource.MIMEType, + Blob: c.Resource.Blob, } } else { - r.Resource = TextResource{ + r.Resource = TextResourceContents{ URI: c.Resource.URI, - MimeType: c.Resource.MIMEType, + MIMEType: c.Resource.MIMEType, Text: c.Resource.Text, } } diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index 950175cb5ac..1984db36be4 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -19,24 +19,24 @@ func TestContent(t *testing.T) { }{ {mcp.TextContent{Text: "hello"}, protocol.Content{Type: "text", Text: "hello"}}, { - mcp.ImageContent{Data: "a1b2c3", MimeType: "image/png"}, - protocol.Content{Type: "image", Data: "a1b2c3", MIMEType: "image/png"}, + mcp.ImageContent{Data: []byte("a1b2c3"), MIMEType: "image/png"}, + protocol.Content{Type: "image", Data: []byte("a1b2c3"), MIMEType: "image/png"}, }, { - mcp.AudioContent{Data: "a1b2c3", MimeType: "audio/wav"}, - protocol.Content{Type: "audio", Data: "a1b2c3", 
MIMEType: "audio/wav"}, + mcp.AudioContent{Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, + protocol.Content{Type: "audio", Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, }, { mcp.ResourceContent{ - Resource: mcp.TextResource{ + Resource: mcp.TextResourceContents{ URI: "file://foo", - MimeType: "text", + MIMEType: "text", Text: "abc", }, }, protocol.Content{ Type: "resource", - Resource: &protocol.Resource{ + Resource: &protocol.ResourceContents{ URI: "file://foo", MIMEType: "text", Text: "abc", @@ -45,18 +45,18 @@ func TestContent(t *testing.T) { }, { mcp.ResourceContent{ - Resource: mcp.BlobResource{ + Resource: mcp.BlobResourceContents{ URI: "file://foo", - MimeType: "text", - Blob: "a1b2c3", + MIMEType: "text", + Blob: []byte("a1b2c3"), }, }, protocol.Content{ Type: "resource", - Resource: &protocol.Resource{ + Resource: &protocol.ResourceContents{ URI: "file://foo", MIMEType: "text", - Blob: ptr("a1b2c3"), + Blob: []byte("a1b2c3"), }, }, }, @@ -69,7 +69,3 @@ func TestContent(t *testing.T) { } } } - -func ptr[T any](t T) *T { - return &t -} diff --git a/internal/mcp/protocol/content.go b/internal/mcp/protocol/content.go index d5fc16f894d..76f017da6cd 100644 --- a/internal/mcp/protocol/content.go +++ b/internal/mcp/protocol/content.go @@ -9,27 +9,35 @@ import ( "fmt" ) +// The []byte fields below are marked omitzero, not omitempty: +// we want to marshal an empty byte slice. + // Content is the wire format for content. -// -// The Type field distinguishes the type of the content. -// At most one of Text, MIMEType, Data, and Resource is non-zero. +// It represents the protocol types TextContent, ImageContent, AudioContent +// and EmbeddedResource. +// The Type field distinguishes them. In the protocol, each type has a constant +// value for the field. +// At most one of Text, Data, and Resource is non-zero. 
type Content struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Data string `json:"data,omitempty"` - Resource *Resource `json:"resource,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitzero"` + Resource *ResourceContents `json:"resource,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` } -// Resource is the wire format for embedded resources. +// A ResourceContents is either a TextResourceContents or a BlobResourceContents. +// See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 +// for the inheritance structure. +// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. // -// The URI field describes the resource location. At most one of Text and Blob -// is non-zero. -type Resource struct { - URI string `json:"uri,"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text"` - Blob *string `json:"blob"` // blob is a pointer to distinguish empty from missing data +// The URI field describes the resource location. +type ResourceContents struct { + URI string `json:"uri,"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text"` + Blob []byte `json:"blob,omitzero"` } func (c *Content) UnmarshalJSON(data []byte) error { From 4d1336a692580c2b8b2fed3b028d93a2b615f450 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 12 May 2025 08:36:32 -0400 Subject: [PATCH 047/196] internal/mcp: add resource subscriptions Handle resource subscriptions by delegating to the user. We don't need to use the word "Resource" in the handlers because the Subscribe and Unsubscribe methods in the spec don't use that name. If the spec adds another kind of subscription, they will pick a new name. The need for two handlers is dictated by backward compatibility. 
We must pass the RPC params to the subscription handler, which means we need different signatures for subscribe and unsubscribe. Change-Id: I045abaf52dc9c3e96181331728d50cf36cc8ed91 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671895 Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 44 ++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 832b79c037e..aeb16b8c704 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -891,20 +891,46 @@ func (*Server) RemoveResourceTemplates(names ...string) Resource templates don't have unique identifiers, so removing a name will remove all resource templates with that name. -Clients call `ListResources` to list the available resources, `ReadResource` to read -one of them, and `ListResourceTemplates` to list the templates: - -```go -func (*ClientSession) ListResources(context.Context, *ListResourcesParams) (*ListResourcesResult, error) -func (*ClientSession) ListResourceTemplates(context.Context, *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) -func (*ClientSession) ReadResource(context.Context, *ReadResourceParams) (*ReadResourceResult, error) -``` +Servers support all of the resource-related spec methods: +- `ListResources` and `ListResourceTemplates` for listings. +- `ReadResource` to get the contents of a resource. +- `Subscribe` and `Unsubscribe` to manage subscriptions on resources. `ReadResource` checks the incoming URI against the server's list of resources and resource templates to make sure it matches one of them, then returns the result of calling the associated reader function. 
- +#### Subscriptions + +ClientSessions can manage change notifications on particular resources: +```go +func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error +func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error +``` + +The server does not implement resource subscriptions. It passes along +subscription requests to the user, and supplies a method to notify clients of +changes. It tracks which sessions have subscribed to which resources so the +user doesn't have to. + +If a server author wants to support resource subscriptions, they must provide handlers +to be called when clients subscribe and unsubscribe. It is an error to provide only +one of these handlers. +```go +type ServerOptions struct { + ... + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeParams) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *UnsubscribeParams) error +} +``` + +User code should call `ResourceUpdated` when a subscribed resource changes. +```go +func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error +``` +The server routes these notifications to the server sessions that subscribed to the resource. ### ListChanged notifications From 8ab19ea7c246da25052528241790a29f9a675828 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 11 May 2025 17:35:03 -0400 Subject: [PATCH 048/196] internal/mcp: describe standard rpc signature Mention that all methods from the spec have the same signature, and then don't repeat those method declarations. Modify all method calls in the doc to take a Params struct, even if the current MCP spec says it is empty. Change `string` to `[]byte` in some of the listed struct fields. The current code generator incorrectly generates these as strings. (I have a forthcoming CL to fix that.) Also, elaborate the logging design. 
Change-Id: I5f323d4d98a4b6045d3cdef770d13b4bfdcb1044 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671358 Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 91 +++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index aeb16b8c704..e1d8582dbf6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -347,7 +347,7 @@ type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` MIMEType string `json:"mimeType,omitempty"` - Data string `json:"data,omitempty"` + Data []byte `json:"data,omitempty"` Resource *Resource `json:"resource,omitempty"` } @@ -359,7 +359,7 @@ type Resource struct { URI string `json:"uri,"` MIMEType string `json:"mimeType,omitempty"` Text string `json:"text"` - Blob *string `json:"blob"` + Blob []byte `json:"blob"` } ``` @@ -438,12 +438,14 @@ transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) if err != nil { ... } // Call a tool on the server. -content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}) +content, err := session.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "you"} , +}) ... return session.Close() ``` - -And here's an example from the server side: +A server that can handle that client call would look like this: ```go // Create a server with a single tool. @@ -463,6 +465,8 @@ session until the client disconnects: func (*Server) Run(context.Context, Transport) ``` + + **Differences from mcp-go**: the Server APIs are very similar to mcp-go, though the association between servers and transports is different. In mcp-go, a single server is bound to what we would call an `SSEHTTPHandler`, @@ -484,6 +488,24 @@ behavior, whereas an options struct is used here. 
We felt that in this case, an options struct would be more readable, and result in cleaner package documentation. +### Spec Methods + +As we saw above, the `ClientSession` method for the specification's +`CallTool` RPC takes a context and a params pointer as arguments, and returns a +result pointer and error: +```go +func (*ClientSession) CallTool(context.Context, *CallToolParams) (*CallToolResult, error) +``` +Our SDK has a method for every RPC in the spec, and their signatures all share +this form. To avoid boilerplate, we don't repeat this signature for RPCs +defined in the spec; readers may assume it when we mention a "spec method." + +Why do we use params instead of the full request? JSON-RPC requests consist of a method +name and a set of parameters, and the method is already encoded in the Go method name. +Technically, the MCP spec could add a field to a request while preserving backward +compatibility, which would break the Go SDK's compatibility. But in the unlikely event +that were to happen, we would add that field to the Params struct. + ### Middleware We provide a mechanism to add MCP-level middleware, which runs after the @@ -589,8 +611,8 @@ Both `ClientSession` and `ServerSession` expose a `Ping` method to call "ping" on their peer. ```go -func (c *ClientSession) Ping(ctx context.Context) error -func (c *ServerSession) Ping(ctx context.Context) error +func (c *ClientSession) Ping(ctx context.Context, *PingParams) error +func (c *ServerSession) Ping(ctx context.Context, *PingParams) error ``` Additionally, client and server sessions can be configured with automatic @@ -634,14 +656,12 @@ func (*Client) AddRoots(roots ...Root) func (*Client) RemoveRoots(uris ...string) ``` -Servers can call `ListRoots` to get the roots. If a server installs a +Servers can call the spec method `ListRoots` to get the roots. 
If a server installs a `RootsChangedHandler`, it will be called when the client sends a roots-changed notification, which happens whenever the list of roots changes after a connection has been established. ```go -func (*Server) ListRoots(context.Context, *ListRootsParams) (*ListRootsResult, error) - type ServerOptions { ... // If non-nil, called when a client sends a roots-changed notification. @@ -652,15 +672,13 @@ type ServerOptions { ### Sampling Clients that support sampling are created with a `CreateMessageHandler` option -for handling server calls. To perform sampling, a server calls `CreateMessage`. +for handling server calls. To perform sampling, a server calls the spec method `CreateMessage`. ```go type ClientOptions struct { ... CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) } - -func (*ServerSession) CreateMessage(context.Context, *CreateMessageParams) (*CreateMessageResult, error) ``` ## Server Features @@ -839,12 +857,7 @@ server.AddPrompts( server.RemovePrompts("code_review") ``` -Clients can call `ListPrompts` to list the available prompts and `GetPrompt` to get one. - -```go -func (*ClientSession) ListPrompts(context.Context, *ListPromptParams) (*ListPromptsResult, error) -func (*ClientSession) GetPrompt(context.Context, *GetPromptParams) (*GetPromptResult, error) -``` +Clients can call the spec method `ListPrompts` to list the available prompts and the spec method `GetPrompt` to get one. **Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. We provide @@ -942,9 +955,9 @@ A client will receive these notifications if it was created with the correspondi ```go type ClientOptions struct { ... 
- ToolListChangedHandler func(context.Context, *ClientConnection, *ToolListChangedParams) - PromptListChangedHandler func(context.Context, *ClientConnection, *PromptListChangedParams) - ResourceListChangedHandler func(context.Context, *ClientConnection, *ResourceListChangedParams) + ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) + PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) + ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) } ``` @@ -954,12 +967,7 @@ feature-specific handlers here. ### Completion -Clients call `Complete` to request completions. - -```go -func (*ClientSession) Complete(context.Context, *CompleteParams) (*CompleteResult, error) -``` - +Clients call the spec method `Complete` to request completions. Servers automatically handle these requests based on their collections of prompts and resources. @@ -968,16 +976,29 @@ defined its server-side behavior. ### Logging +Server-to-client logging is configured with `ServerOptions`: + +```go +type ServerOptions { + ... + // The value for the "logger" field of the notification. + LoggerName string + // Log notifications to a single ClientSession will not be + // send more frequently than this duration. + LogInterval time.Duration +} +``` + ServerSessions have access to a `slog.Logger` that writes to the client. A call to a log method like `Info`is translated to a `LoggingMessageNotification` as follows: -- An attribute with key "logger" is used to populate the "logger" field of the notification. - -- The remaining attributes and the message populate the "data" field with the +- The attributes and the message populate the "data" property with the output of a `slog.JSONHandler`: The result is always a JSON object, with the key "msg" for the message. +- If the `LoggerName` server option is set, it populates the "logger" property. 
+ - The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the corresponding levels in the MCP spec. The other spec levels will be mapped to integers between the slog levels. For example, "notice" is level 2 because @@ -985,6 +1006,14 @@ follows: The `mcp` package defines consts for these levels. To log at the "notice" level, a handler would call `session.Log(ctx, mcp.LevelNotice, "message")`. +A client that wishes to receive log messages must provide a handler: +```go +type ClientOptions struct { + ... + LogMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) +} +``` + ### Pagination From 865cd206823d2f102aa41db943ce0f1a58215c84 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Sun, 27 Apr 2025 11:40:37 -0400 Subject: [PATCH 049/196] x/tools: various cleanups related to planned parser changes This CL extracts the good bits out of a large exploratory patch (CL 668677) to update x/tools in anticipation of parser and AST changes that I was hoping to land in go1.25; however that isn't going to happen. 
Updates golang/go#73438 Updates golang/go#66790 Updates golang/go#66683 Updates golang/go#67704 Change-Id: Iba4a0a7c4a93d04fc6d46466c9fb9980d52067a3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672055 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- go/ast/astutil/enclosing.go | 21 ++++-- gopls/internal/cache/check.go | 6 +- gopls/internal/cache/parsego/parse.go | 71 +++++++++---------- gopls/internal/cache/parsego/parse_test.go | 4 +- gopls/internal/cache/snapshot.go | 5 +- gopls/internal/fuzzy/matcher.go | 2 + gopls/internal/golang/addtest.go | 12 ++-- .../internal/golang/completion/completion.go | 40 ++++++----- gopls/internal/golang/modify_tags.go | 10 --- gopls/internal/golang/rename.go | 2 +- gopls/internal/golang/signature_help.go | 19 +++-- .../integration/completion/completion_test.go | 4 ++ gopls/internal/test/marker/marker_test.go | 2 +- .../codeaction/extract-variadic-63287.txt | 3 +- 14 files changed, 102 insertions(+), 99 deletions(-) diff --git a/go/ast/astutil/enclosing.go b/go/ast/astutil/enclosing.go index 6e34df46130..89f5097be00 100644 --- a/go/ast/astutil/enclosing.go +++ b/go/ast/astutil/enclosing.go @@ -207,6 +207,9 @@ func childrenOf(n ast.Node) []ast.Node { return false // no recursion }) + // TODO(adonovan): be more careful about missing (!Pos.Valid) + // tokens in trees produced from invalid input. + // Then add fake Nodes for bare tokens. 
switch n := n.(type) { case *ast.ArrayType: @@ -226,9 +229,12 @@ func childrenOf(n ast.Node) []ast.Node { children = append(children, tok(n.OpPos, len(n.Op.String()))) case *ast.BlockStmt: - children = append(children, - tok(n.Lbrace, len("{")), - tok(n.Rbrace, len("}"))) + if n.Lbrace.IsValid() { + children = append(children, tok(n.Lbrace, len("{"))) + } + if n.Rbrace.IsValid() { + children = append(children, tok(n.Rbrace, len("}"))) + } case *ast.BranchStmt: children = append(children, @@ -304,9 +310,12 @@ func childrenOf(n ast.Node) []ast.Node { // TODO(adonovan): Field.{Doc,Comment,Tag}? case *ast.FieldList: - children = append(children, - tok(n.Opening, len("(")), // or len("[") - tok(n.Closing, len(")"))) // or len("]") + if n.Opening.IsValid() { + children = append(children, tok(n.Opening, len("("))) + } + if n.Closing.IsValid() { + children = append(children, tok(n.Closing, len(")"))) + } case *ast.File: // TODO test: Doc diff --git a/gopls/internal/cache/check.go b/gopls/internal/cache/check.go index 909003288bc..bee0616c8a1 100644 --- a/gopls/internal/cache/check.go +++ b/gopls/internal/cache/check.go @@ -2070,12 +2070,14 @@ func typeErrorsToDiagnostics(pkg *syntaxPackage, inputs *typeCheckInputs, errs [ } } } else { + // TODO(adonovan): check File(start)==File(end). 
+ // debugging golang/go#65960 if _, err := safetoken.Offset(pgf.Tok, end); err != nil { if pkg.hasFixedFiles() { - bug.Reportf("ReadGo116ErrorData returned invalid end: %v (fixed files)", err) + bug.Reportf("ErrorCodeStartEnd returned invalid end: %v (fixed files)", err) } else { - bug.Reportf("ReadGo116ErrorData returned invalid end: %v", err) + bug.Reportf("ErrorCodeStartEnd returned invalid end: %v", err) } } } diff --git a/gopls/internal/cache/parsego/parse.go b/gopls/internal/cache/parsego/parse.go index bc5483fc166..9a6bdf03da3 100644 --- a/gopls/internal/cache/parsego/parse.go +++ b/gopls/internal/cache/parsego/parse.go @@ -210,6 +210,7 @@ func fixAST(n ast.Node, tok *token.File, src []byte) (fixes []FixType) { // walkASTWithParent walks the AST rooted at n. The semantics are // similar to ast.Inspect except it does not call f(nil). +// TODO(adonovan): replace with PreorderStack. func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) { var ancestors []ast.Node ast.Inspect(n, func(n ast.Node) (recurse bool) { @@ -422,8 +423,10 @@ func fixEmptySwitch(body *ast.BlockStmt, tok *token.File, src []byte) bool { return true } -// fixDanglingSelector inserts real "_" selector expressions in place -// of phantom "_" selectors. For example: +// fixDanglingSelector inserts a real "_" selector expression in place +// of a phantom parser-inserted "_" selector so that the parser will +// not consume the following non-identifier token. +// For example: // // func _() { // x.<> @@ -453,17 +456,13 @@ func fixDanglingSelector(s *ast.SelectorExpr, tf *token.File, src []byte) []byte return nil } - var buf bytes.Buffer - buf.Grow(len(src) + 1) - buf.Write(src[:insertOffset]) - buf.WriteByte('_') - buf.Write(src[insertOffset:]) - return buf.Bytes() + return slices.Concat(src[:insertOffset], []byte("_"), src[insertOffset:]) } -// fixPhantomSelector tries to fix selector expressions with phantom -// "_" selectors. 
In particular, we check if the selector is a -// keyword, and if so we swap in an *ast.Ident with the keyword text. For example: +// fixPhantomSelector tries to fix selector expressions whose Sel is a +// phantom (parser-invented) "_". If the text after the '.' is a +// keyword, it updates Sel to a fake ast.Ident of that name. For +// example: // // foo.var // @@ -498,21 +497,18 @@ func fixPhantomSelector(sel *ast.SelectorExpr, tf *token.File, src []byte) bool }) } -// isPhantomUnderscore reports whether the given ident is a phantom -// underscore. The parser sometimes inserts phantom underscores when -// it encounters otherwise unparseable situations. +// isPhantomUnderscore reports whether the given ident from a +// SelectorExpr.Sel was invented by the parser and is not present in +// source text. The parser creates a blank "_" identifier when the +// syntax (e.g. a selector) demands one but none is present. The fixer +// also inserts them. func isPhantomUnderscore(id *ast.Ident, tok *token.File, src []byte) bool { - if id == nil || id.Name != "_" { - return false + switch id.Name { + case "_": // go1.24 parser + offset, err := safetoken.Offset(tok, id.Pos()) + return err == nil && offset < len(src) && src[offset] != '_' } - - // Phantom underscore means the underscore is not actually in the - // program text. - offset, err := safetoken.Offset(tok, id.Pos()) - if err != nil { - return false - } - return len(src) <= offset || src[offset] != '_' + return false // real } // fixInitStmt fixes cases where the parser misinterprets an @@ -821,11 +817,7 @@ FindTo: // positions are valid. func parseStmt(tok *token.File, pos token.Pos, src []byte) (ast.Stmt, error) { // Wrap our expression to make it a valid Go file we can pass to ParseFile. 
- fileSrc := bytes.Join([][]byte{ - []byte("package fake;func _(){"), - src, - []byte("}"), - }, nil) + fileSrc := slices.Concat([]byte("package fake;func _(){"), src, []byte("}")) // Use ParseFile instead of ParseExpr because ParseFile has // best-effort behavior, whereas ParseExpr fails hard on any error. @@ -873,8 +865,8 @@ var tokenPosType = reflect.TypeOf(token.NoPos) // offsetPositions applies an offset to the positions in an ast.Node. func offsetPositions(tok *token.File, n ast.Node, offset token.Pos) { - fileBase := int64(tok.Base()) - fileEnd := fileBase + int64(tok.Size()) + fileBase := token.Pos(tok.Base()) + fileEnd := fileBase + token.Pos(tok.Size()) ast.Inspect(n, func(n ast.Node) bool { if n == nil { return false @@ -894,20 +886,21 @@ func offsetPositions(tok *token.File, n ast.Node, offset token.Pos) { continue } + pos := token.Pos(f.Int()) + // Don't offset invalid positions: they should stay invalid. - if !token.Pos(f.Int()).IsValid() { + if !pos.IsValid() { continue } // Clamp value to valid range; see #64335. // // TODO(golang/go#64335): this is a hack, because our fixes should not - // produce positions that overflow (but they do: golang/go#64488). - pos := max(f.Int()+int64(offset), fileBase) - if pos > fileEnd { - pos = fileEnd - } - f.SetInt(pos) + // produce positions that overflow (but they do; see golang/go#64488, + // #73438, #66790, #66683, #67704). + pos = min(max(pos+offset, fileBase), fileEnd) + + f.SetInt(int64(pos)) } } @@ -950,7 +943,7 @@ func replaceNode(parent, oldChild, newChild ast.Node) bool { switch f.Kind() { // Check interface and pointer fields. 
- case reflect.Interface, reflect.Ptr: + case reflect.Interface, reflect.Pointer: if tryReplace(f) { return true } diff --git a/gopls/internal/cache/parsego/parse_test.go b/gopls/internal/cache/parsego/parse_test.go index db78b596042..cbbc32e2723 100644 --- a/gopls/internal/cache/parsego/parse_test.go +++ b/gopls/internal/cache/parsego/parse_test.go @@ -300,14 +300,14 @@ func TestFixPhantomSelector(t *testing.T) { // ensure the selector has been converted to underscore by parser. ensureSource(t, src, func(sel *ast.SelectorExpr) { if sel.Sel.Name != "_" { - t.Errorf("%s: the input doesn't cause a blank selector after parser", tc.source) + t.Errorf("%s: selector name is %q, want _", tc.source, sel.Sel.Name) } }) fset := tokeninternal.FileSetFor(pgf.Tok) inspect(t, pgf, func(sel *ast.SelectorExpr) { // the fix should restore the selector as is. - if got, want := fmt.Sprintf("%s", analysisinternal.Format(fset, sel)), tc.source; got != want { + if got, want := analysisinternal.Format(fset, sel), tc.source; got != want { t.Fatalf("got %v want %v", got, want) } }) diff --git a/gopls/internal/cache/snapshot.go b/gopls/internal/cache/snapshot.go index f936bbfc458..8dda86071de 100644 --- a/gopls/internal/cache/snapshot.go +++ b/gopls/internal/cache/snapshot.go @@ -1463,10 +1463,11 @@ func orphanedFileDiagnosticRange(ctx context.Context, cache *parseCache, fh file return nil, protocol.Range{}, false } pgf := pgfs[0] - if !pgf.File.Name.Pos().IsValid() { + name := pgf.File.Name + if !name.Pos().IsValid() { return nil, protocol.Range{}, false } - rng, err := pgf.PosRange(pgf.File.Name.Pos(), pgf.File.Name.End()) + rng, err := pgf.PosRange(name.Pos(), name.End()) if err != nil { return nil, protocol.Range{}, false } diff --git a/gopls/internal/fuzzy/matcher.go b/gopls/internal/fuzzy/matcher.go index eff86efac34..8ce7e7ff3dd 100644 --- a/gopls/internal/fuzzy/matcher.go +++ b/gopls/internal/fuzzy/matcher.go @@ -61,6 +61,8 @@ type Matcher struct { rolesBuf [MaxInputSize]RuneRole } 
+func (m *Matcher) String() string { return m.pattern } + func (m *Matcher) bestK(i, j int) int { if m.scores[i][j][0].val() < m.scores[i][j][1].val() { return 1 diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index 3a5b1e03308..89d0be3d1fd 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -13,7 +13,6 @@ import ( "fmt" "go/ast" "go/format" - "go/token" "go/types" "os" "path/filepath" @@ -395,25 +394,26 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. NewText: header.String(), }) } else { // existing _test.go file. - if testPGF.File.Name == nil || testPGF.File.Name.NamePos == token.NoPos { + file := testPGF.File + if !file.Name.NamePos.IsValid() { return nil, fmt.Errorf("missing package declaration") } - switch testPGF.File.Name.Name { + switch file.Name.Name { case pgf.File.Name.Name: xtest = false case pgf.File.Name.Name + "_test": xtest = true default: - return nil, fmt.Errorf("invalid package declaration %q in test file %q", testPGF.File.Name, testPGF) + return nil, fmt.Errorf("invalid package declaration %q in test file %q", file.Name, testPGF) } - eofRange, err = testPGF.PosRange(testPGF.File.FileEnd, testPGF.File.FileEnd) + eofRange, err = testPGF.PosRange(file.FileEnd, file.FileEnd) if err != nil { return nil, err } // Collect all the imports from the foo_test.go. 
- if testImports, err = collectImports(testPGF.File); err != nil { + if testImports, err = collectImports(file); err != nil { return nil, err } } diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index ddaaac15ece..f61fdc6f7ba 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "go/ast" - "go/build" "go/constant" "go/parser" "go/printer" @@ -505,7 +504,9 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p startTime := time.Now() pkg, pgf, err := golang.NarrowestPackageForFile(ctx, snapshot, fh.URI()) - if err != nil || pgf.File.Package == token.NoPos { + if err != nil || !pgf.File.Package.IsValid() { + // Invalid package declaration + // // If we can't parse this file or find position for the package // keyword, it may be missing a package declaration. Try offering // suggestions for the package declaration. @@ -586,12 +587,6 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p } scopes = append(scopes, pkg.Types().Scope(), types.Universe) - var goversion string // "" => no version check - // Prior go1.22, the behavior of FileVersion is not useful to us. 
- if slices.Contains(build.Default.ReleaseTags, "go1.22") { - goversion = versions.FileVersion(info, pgf.File) // may be "" - } - opts := snapshot.Options() c := &completer{ pkg: pkg, @@ -605,7 +600,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p fh: fh, filename: fh.URI().Path(), pgf: pgf, - goversion: goversion, + goversion: versions.FileVersion(info, pgf.File), // may be "" => no version check path: path, pos: pos, seen: make(map[types.Object]bool), @@ -746,24 +741,30 @@ func (c *completer) collectCompletions(ctx context.Context) error { if c.pgf.File.Name == n { return c.packageNameCompletions(ctx, c.fh.URI(), n) } else if sel, ok := c.path[1].(*ast.SelectorExpr); ok && sel.Sel == n { - // Is this the Sel part of a selector? + // We are in the Sel part of a selector (e.g. x.‸sel or x.sel‸). return c.selector(ctx, sel) } return c.lexical(ctx) - // The function name hasn't been typed yet, but the parens are there: - // recv.‸(arg) + case *ast.TypeAssertExpr: + // The function name hasn't been typed yet, but the parens are there: + // recv.‸(arg) // Create a fake selector expression. - // + // The name "_" is the convention used by go/parser to represent phantom // selectors. sel := &ast.Ident{NamePos: n.X.End() + token.Pos(len(".")), Name: "_"} return c.selector(ctx, &ast.SelectorExpr{X: n.X, Sel: sel}) + case *ast.SelectorExpr: + // We are in the X part of a selector (x‸.sel), + // or after the dot with a fixed/phantom Sel (x.‸_). return c.selector(ctx, n) - // At the file scope, only keywords are allowed. + case *ast.BadDecl, *ast.File: + // At the file scope, only keywords are allowed. c.addKeywordCompletions() + default: // fallback to lexical completions return c.lexical(ctx) @@ -823,6 +824,8 @@ func (c *completer) scanToken(contents []byte) (token.Pos, token.Token, string) tok := c.pkg.FileSet().File(c.pos) var s scanner.Scanner + // TODO(adonovan): fix! 
this mutates the token.File borrowed from c.pkg, + // calling AddLine and AddLineColumnInfo. Not sound! s.Init(tok, contents, nil, 0) for { tknPos, tkn, lit := s.Scan() @@ -1232,6 +1235,9 @@ const ( ) // selector finds completions for the specified selector expression. +// +// The caller should ensure that sel.X has type information, +// even if sel is synthetic. func (c *completer) selector(ctx context.Context, sel *ast.SelectorExpr) error { c.inference.objChain = objChain(c.pkg.TypesInfo(), sel.X) @@ -1283,7 +1289,7 @@ func (c *completer) selector(ctx context.Context, sel *ast.SelectorExpr) error { // -- completion of symbols in unimported packages -- // use new code for unimported completions, if flag allows it - if id, ok := sel.X.(*ast.Ident); ok && c.snapshot.Options().ImportsSource == settings.ImportsSourceGopls { + if c.snapshot.Options().ImportsSource == settings.ImportsSourceGopls { // The user might have typed strings.TLower, so id.Name==strings, sel.Sel.Name == TLower, // but the cursor might be inside TLower, so adjust the prefix prefix := sel.Sel.Name @@ -2916,9 +2922,7 @@ func objChain(info *types.Info, e ast.Expr) []types.Object { } // Reverse order so the layout matches the syntactic order. 
- for i := range len(objs) / 2 { - objs[i], objs[len(objs)-1-i] = objs[len(objs)-1-i], objs[i] - } + slices.Reverse(objs) return objs } diff --git a/gopls/internal/golang/modify_tags.go b/gopls/internal/golang/modify_tags.go index 46748c841d1..c0a6b832730 100644 --- a/gopls/internal/golang/modify_tags.go +++ b/gopls/internal/golang/modify_tags.go @@ -10,7 +10,6 @@ import ( "fmt" "go/ast" "go/format" - "go/token" "github.com/fatih/gomodifytags/modifytags" "golang.org/x/tools/gopls/internal/cache" @@ -20,19 +19,10 @@ import ( "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/gopls/internal/util/moreiters" internalastutil "golang.org/x/tools/internal/astutil" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/tokeninternal" ) -// Finds the start and end positions of the enclosing struct or returns an error if none is found. -func findEnclosingStruct(c cursor.Cursor) (token.Pos, token.Pos, error) { - for cur := range c.Enclosing((*ast.StructType)(nil)) { - return cur.Node().Pos(), cur.Node().End(), nil - } - return token.NoPos, token.NoPos, fmt.Errorf("no struct enclosing the given positions") -} - // ModifyTags applies the given struct tag modifications to the specified struct. func ModifyTags(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, args command.ModifyTagsArgs, m *modifytags.Modification) ([]protocol.DocumentChange, error) { pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) diff --git a/gopls/internal/golang/rename.go b/gopls/internal/golang/rename.go index f23f179c6ff..f11a3356d74 100644 --- a/gopls/internal/golang/rename.go +++ b/gopls/internal/golang/rename.go @@ -1658,7 +1658,7 @@ func parsePackageNameDecl(ctx context.Context, snapshot *cache.Snapshot, fh file // Careful: because we used parsego.Header, // pgf.Pos(ppos) may be beyond EOF => (0, err). 
pos, _ := pgf.PositionPos(ppos) - return pgf, pgf.File.Name.Pos() <= pos && pos <= pgf.File.Name.End(), nil + return pgf, goplsastutil.NodeContains(pgf.File.Name, pos), nil } // enclosingFile returns the CompiledGoFile of pkg that contains the specified position. diff --git a/gopls/internal/golang/signature_help.go b/gopls/internal/golang/signature_help.go index 1dbd76d57d0..873111d20d9 100644 --- a/gopls/internal/golang/signature_help.go +++ b/gopls/internal/golang/signature_help.go @@ -5,6 +5,7 @@ package golang import ( + "cmp" "context" "fmt" "go/ast" @@ -51,7 +52,7 @@ loop: for i, node := range path { switch node := node.(type) { case *ast.Ident: - // If the selected text is a function/method Ident orSelectorExpr, + // If the selected text is a function/method Ident or SelectorExpr, // even one not in function call position, // show help for its signature. Example: // once.Do(initialize⁁) @@ -67,7 +68,8 @@ loop: break loop } case *ast.CallExpr: - if pos >= node.Lparen && pos <= node.Rparen { + // Beware: the ')' may be missing. + if pos >= node.Lparen && pos <= cmp.Or(node.Rparen, node.End()) { callExpr = node fnval = callExpr.Fun break loop @@ -88,7 +90,6 @@ loop: return nil, 0, nil } } - } if fnval == nil { @@ -194,19 +195,17 @@ func builtinSignature(ctx context.Context, snapshot *cache.Snapshot, callExpr *a }, activeParam, nil } -func activeParameter(callExpr *ast.CallExpr, numParams int, variadic bool, pos token.Pos) (activeParam int) { - if len(callExpr.Args) == 0 { +func activeParameter(call *ast.CallExpr, numParams int, variadic bool, pos token.Pos) (activeParam int) { + if len(call.Args) == 0 { return 0 } // First, check if the position is even in the range of the arguments. - start, end := callExpr.Lparen, callExpr.Rparen + // Beware: the Rparen may be missing. 
+ start, end := call.Lparen, cmp.Or(call.Rparen, call.End()) if !(start <= pos && pos <= end) { return 0 } - for _, expr := range callExpr.Args { - if start == token.NoPos { - start = expr.Pos() - } + for _, expr := range call.Args { end = expr.End() if start <= pos && pos <= end { break diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index eb3d0a34161..59f10f8dff0 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -385,6 +385,10 @@ const Name = "mainmod" // Test that we can doctor the source code enough so the file is // parseable and completion works as expected. func TestSourceFixup(t *testing.T) { + // This example relies on the fixer to turn "s." into "s._" so + // that it parses as a SelectorExpr with only local problems, + // instead of snarfing up the following declaration of S + // looking for an identifier; thus completion offers s.i. const files = ` -- go.mod -- module mod.com diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 51aba838b1b..dcc005ba391 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -1591,7 +1591,7 @@ func completeMarker(mark marker, src protocol.Location, want ...completionItem) want = nil // got is nil if empty } if diff := cmp.Diff(want, got); diff != "" { - mark.errorf("Completion(...) returned unexpect results (-want +got):\n%s", diff) + mark.errorf("Completion(...) 
returned unexpected results (-want +got):\n%s", diff) } } diff --git a/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt b/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt index afabcf49f2a..0e363f811f2 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt @@ -1,6 +1,6 @@ This test exercises extract on a variadic function. It is a regression test for bug #63287 in which -the final paramater's "..." would go missing. +the final parameter's "..." would go missing. -- go.mod -- module example.com @@ -25,4 +25,3 @@ func _() { + println(logf) +} + --- end -- From ade411c2a0fd1fa0e332a95c10f10907b22a9306 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Thu, 17 Apr 2025 12:13:50 -0400 Subject: [PATCH 050/196] gopls/internal/lsprpc: start mcp server by passing -mcp-listen Use errgroup.WithContext to manage the lifecycle of the LSP and MCP servers. Both servers always return an error, even upon successful completion, so that the error group can cancel the context. Only the first error is returned. Creates an event channel for sending session event notifications from the lsp server to the mcp server. Whenever a new lsp session is established, a new session event is sent from the lsp server to the mcp server with the session cache. Whenever an existing lsp session closes, an exiting session event is sent from the lsp server to the mcp server. Both cache.Cache and cache.Session are shared between the lsp server and the mcp server for static analysis. For now, the mcp server exposes only a dummy hello tool, without leveraging the cache to perform static analysis.
For golang/go#73580 Change-Id: If511ada68ac1215cdd8440c195fcd9dc80a875fd Reviewed-on: https://go-review.googlesource.com/c/tools/+/670397 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Auto-Submit: Hongxiang Jiang --- gopls/internal/cmd/serve.go | 143 ++++++++++++----- gopls/internal/cmd/usage/serve.hlp | 2 + gopls/internal/cmd/usage/usage-v.hlp | 2 + gopls/internal/cmd/usage/usage.hlp | 2 + gopls/internal/lsprpc/lsprpc.go | 24 ++- gopls/internal/lsprpc/lsprpc_test.go | 7 +- gopls/internal/mcp/mcp.go | 149 ++++++++++++++++++ gopls/internal/mcp/mcp_test.go | 36 +++++ .../test/integration/bench/stress_test.go | 5 +- gopls/internal/test/integration/runner.go | 4 +- gopls/internal/test/marker/marker_test.go | 2 +- gopls/internal/util/moremaps/maps.go | 9 ++ 12 files changed, 332 insertions(+), 53 deletions(-) create mode 100644 gopls/internal/mcp/mcp.go create mode 100644 gopls/internal/mcp/mcp_test.go diff --git a/gopls/internal/cmd/serve.go b/gopls/internal/cmd/serve.go index 16f3b160a73..7da129c8f2a 100644 --- a/gopls/internal/cmd/serve.go +++ b/gopls/internal/cmd/serve.go @@ -14,17 +14,19 @@ import ( "os" "time" + "golang.org/x/sync/errgroup" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/debug" "golang.org/x/tools/gopls/internal/lsprpc" + "golang.org/x/tools/gopls/internal/mcp" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/internal/fakenet" "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/tool" ) -// Serve is a struct that exposes the configurable parts of the LSP server as -// flags, in the right form for tool.Main to consume. +// Serve is a struct that exposes the configurable parts of the LSP and MCP +// server as flags, in the right form for tool.Main to consume. type Serve struct { Logfile string `flag:"logfile" help:"filename to log to. 
if value is \"auto\", then logging to a default output file is enabled"` Mode string `flag:"mode" help:"no effect"` @@ -38,6 +40,9 @@ type Serve struct { RemoteDebug string `flag:"remote.debug" help:"when used with -remote=auto, the -debug value used to start the daemon"` RemoteLogfile string `flag:"remote.logfile" help:"when used with -remote=auto, the -logfile value used to start the daemon"` + // MCP Server related configurations. + MCPAddress string `flag:"mcp-listen" help:"experimental: address on which to listen for model context protocol connections. If port is localhost:0, pick a random port in localhost instead."` + app *Application } @@ -92,7 +97,17 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { di.ServerAddress = s.Address di.Serve(ctx, s.Debug) } + var ss jsonrpc2.StreamServer + + // eventChan is used by the LSP server to send session lifecycle events + // (creation, exit) to the MCP server. The sender must ensure that an exit + // event for a given LSP session ID is sent after its corresponding creation + // event. + var eventChan chan mcp.SessionEvent + // cache shared between MCP and LSP servers. + var ca *cache.Cache + if s.app.Remote != "" { var err error ss, err = lsprpc.NewForwarder(s.app.Remote, s.remoteArgs) @@ -100,47 +115,93 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { return fmt.Errorf("creating forwarder: %w", err) } } else { - ss = lsprpc.NewStreamServer(cache.New(nil), isDaemon, s.app.options) + if s.MCPAddress != "" { + eventChan = make(chan mcp.SessionEvent) + } + ca = cache.New(nil) + ss = lsprpc.NewStreamServer(ca, isDaemon, eventChan, s.app.options) } - var network, addr string - if s.Address != "" { - network, addr = lsprpc.ParseAddr(s.Address) - } - if s.Port != 0 { - network = "tcp" - // TODO(adonovan): should gopls ever be listening on network - // sockets, or only local ones? 
- // - // Ian says this was added in anticipation of - // something related to "VS Code remote" that turned - // out to be unnecessary. So I propose we limit it to - // localhost, if only so that we avoid the macOS - // firewall prompt. - // - // Hana says: "s.Address is for the remote access (LSP) - // and s.Port is for debugging purpose (according to - // the Server type documentation). I am not sure why the - // existing code here is mixing up and overwriting addr. - // For debugging endpoint, I think localhost makes perfect sense." - // - // TODO(adonovan): disentangle Address and Port, - // and use only localhost for the latter. - addr = fmt.Sprintf(":%v", s.Port) - } - if addr != "" { - log.Printf("Gopls daemon: listening on %s network, address %s...", network, addr) - defer log.Printf("Gopls daemon: exiting") - return jsonrpc2.ListenAndServe(ctx, network, addr, ss, s.IdleTimeout) - } - stream := jsonrpc2.NewHeaderStream(fakenet.NewConn("stdio", os.Stdin, os.Stdout)) - if s.Trace && di != nil { - stream = protocol.LoggingStream(stream, di.LogWriter) + group, ctx := errgroup.WithContext(ctx) + // Indicate success by a special error so that successful termination + // of one server causes cancellation of the other. + sucess := errors.New("success") + + // Start MCP server. + if eventChan != nil { + group.Go(func() (err error) { + defer func() { + if err == nil { + err = sucess + } + }() + + return mcp.Serve(ctx, s.MCPAddress, eventChan, ca, isDaemon) + }) } - conn := jsonrpc2.NewConn(stream) - err := ss.ServeStream(ctx, conn) - if errors.Is(err, io.EOF) { - return nil + + // Start LSP server. + group.Go(func() (err error) { + defer func() { + // Once we have finished serving LSP over jsonrpc or stdio, + // there can be no more session events. Notify the MCP server. 
+ if eventChan != nil { + close(eventChan) + } + if err == nil { + err = sucess + } + }() + + var network, addr string + if s.Address != "" { + network, addr = lsprpc.ParseAddr(s.Address) + } + if s.Port != 0 { + network = "tcp" + // TODO(adonovan): should gopls ever be listening on network + // sockets, or only local ones? + // + // Ian says this was added in anticipation of + // something related to "VS Code remote" that turned + // out to be unnecessary. So I propose we limit it to + // localhost, if only so that we avoid the macOS + // firewall prompt. + // + // Hana says: "s.Address is for the remote access (LSP) + // and s.Port is for debugging purpose (according to + // the Server type documentation). I am not sure why the + // existing code here is mixing up and overwriting addr. + // For debugging endpoint, I think localhost makes perfect sense." + // + // TODO(adonovan): disentangle Address and Port, + // and use only localhost for the latter. + addr = fmt.Sprintf(":%v", s.Port) + } + + if addr != "" { + log.Printf("Gopls LSP daemon: listening on %s network, address %s...", network, addr) + defer log.Printf("Gopls LSP daemon: exiting") + return jsonrpc2.ListenAndServe(ctx, network, addr, ss, s.IdleTimeout) + } else { + stream := jsonrpc2.NewHeaderStream(fakenet.NewConn("stdio", os.Stdin, os.Stdout)) + if s.Trace && di != nil { + stream = protocol.LoggingStream(stream, di.LogWriter) + } + conn := jsonrpc2.NewConn(stream) + if err := ss.ServeStream(ctx, conn); errors.Is(err, io.EOF) { + return nil + } else { + return err + } + } + }) + + // Wait for all servers to terminate, returning only the first error + // encountered. Subsequent errors are typically due to context cancellation + // and are disregarded. 
+ if err := group.Wait(); err != nil && !errors.Is(err, sucess) { + return err } - return err + return nil } diff --git a/gopls/internal/cmd/usage/serve.hlp b/gopls/internal/cmd/usage/serve.hlp index 370cbce83df..26c3d540ee0 100644 --- a/gopls/internal/cmd/usage/serve.hlp +++ b/gopls/internal/cmd/usage/serve.hlp @@ -16,6 +16,8 @@ server-flags: when used with -listen, shut down the server when there are no connected clients for this duration -logfile=string filename to log to. if value is "auto", then logging to a default output file is enabled + -mcp-listen=string + experimental: address on which to listen for model context protocol connections. If port is localhost:0, pick a random port in localhost instead. -mode=string no effect -port=int diff --git a/gopls/internal/cmd/usage/usage-v.hlp b/gopls/internal/cmd/usage/usage-v.hlp index 044d4251e89..ae5bd9bff0c 100644 --- a/gopls/internal/cmd/usage/usage-v.hlp +++ b/gopls/internal/cmd/usage/usage-v.hlp @@ -59,6 +59,8 @@ flags: when used with -listen, shut down the server when there are no connected clients for this duration -logfile=string filename to log to. if value is "auto", then logging to a default output file is enabled + -mcp-listen=string + experimental: address on which to listen for model context protocol connections. If port is localhost:0, pick a random port in localhost instead. -mode=string no effect -port=int diff --git a/gopls/internal/cmd/usage/usage.hlp b/gopls/internal/cmd/usage/usage.hlp index b918b24a411..a06fff583d5 100644 --- a/gopls/internal/cmd/usage/usage.hlp +++ b/gopls/internal/cmd/usage/usage.hlp @@ -56,6 +56,8 @@ flags: when used with -listen, shut down the server when there are no connected clients for this duration -logfile=string filename to log to. if value is "auto", then logging to a default output file is enabled + -mcp-listen=string + experimental: address on which to listen for model context protocol connections. If port is localhost:0, pick a random port in localhost instead. 
-mode=string no effect -port=int diff --git a/gopls/internal/lsprpc/lsprpc.go b/gopls/internal/lsprpc/lsprpc.go index 3d26bdd6896..b7fb40139f9 100644 --- a/gopls/internal/lsprpc/lsprpc.go +++ b/gopls/internal/lsprpc/lsprpc.go @@ -22,6 +22,7 @@ import ( "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/debug" "golang.org/x/tools/gopls/internal/label" + "golang.org/x/tools/gopls/internal/mcp" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/gopls/internal/server" @@ -45,13 +46,17 @@ type streamServer struct { // serverForTest may be set to a test fake for testing. serverForTest protocol.Server + + // eventChan is an optional channel for LSP server session lifecycle events, + // including session creation and termination. If nil, no events are sent. + eventChan chan mcp.SessionEvent } // NewStreamServer creates a StreamServer using the shared cache. If // withTelemetry is true, each session is instrumented with telemetry that // records RPC statistics. 
-func NewStreamServer(cache *cache.Cache, daemon bool, optionsFunc func(*settings.Options)) jsonrpc2.StreamServer { - return &streamServer{cache: cache, daemon: daemon, optionsOverrides: optionsFunc} +func NewStreamServer(cache *cache.Cache, daemon bool, eventChan chan mcp.SessionEvent, optionsFunc func(*settings.Options)) jsonrpc2.StreamServer { + return &streamServer{cache: cache, daemon: daemon, eventChan: eventChan, optionsOverrides: optionsFunc} } // ServeStream implements the jsonrpc2.StreamServer interface, by handling @@ -86,10 +91,25 @@ func (s *streamServer) ServeStream(ctx context.Context, conn jsonrpc2.Conn) erro handshaker(session, executable, s.daemon, protocol.ServerHandler(svr, jsonrpc2.MethodNotFound)))) + + if s.eventChan != nil { + s.eventChan <- mcp.SessionEvent{ + Session: session, + Type: mcp.SessionNew, + } + defer func() { + s.eventChan <- mcp.SessionEvent{ + Session: session, + Type: mcp.SessionExiting, + } + }() + } + if s.daemon { log.Printf("Session %s: connected", session.ID()) defer log.Printf("Session %s: exited", session.ID()) } + <-conn.Done() return conn.Err() } diff --git a/gopls/internal/lsprpc/lsprpc_test.go b/gopls/internal/lsprpc/lsprpc_test.go index c8f0267cc3c..d3018383fcd 100644 --- a/gopls/internal/lsprpc/lsprpc_test.go +++ b/gopls/internal/lsprpc/lsprpc_test.go @@ -58,7 +58,7 @@ func TestClientLogging(t *testing.T) { client := FakeClient{Logs: make(chan string, 10)} ctx = debug.WithInstance(ctx) - ss := NewStreamServer(cache.New(nil), false, nil).(*StreamServer) + ss := NewStreamServer(cache.New(nil), false, nil, nil).(*StreamServer) ss.serverForTest = server ts := servertest.NewPipeServer(ss, nil) defer checkClose(t, ts.Close) @@ -121,7 +121,7 @@ func checkClose(t *testing.T, closer func() error) { func setupForwarding(ctx context.Context, t *testing.T, s protocol.Server) (direct, forwarded servertest.Connector, cleanup func()) { t.Helper() serveCtx := debug.WithInstance(ctx) - ss := NewStreamServer(cache.New(nil), 
false, nil).(*StreamServer) + ss := NewStreamServer(cache.New(nil), false, nil, nil).(*StreamServer) ss.serverForTest = s tsDirect := servertest.NewTCPServer(serveCtx, ss, nil) @@ -215,8 +215,7 @@ func TestDebugInfoLifecycle(t *testing.T) { clientCtx := debug.WithInstance(baseCtx) serverCtx := debug.WithInstance(baseCtx) - cache := cache.New(nil) - ss := NewStreamServer(cache, false, nil) + ss := NewStreamServer(cache.New(nil), false, nil, nil) tsBackend := servertest.NewTCPServer(serverCtx, ss, nil) forwarder, err := NewForwarder("tcp;"+tsBackend.Addr, nil) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go new file mode 100644 index 00000000000..ac09ffc300c --- /dev/null +++ b/gopls/internal/mcp/mcp.go @@ -0,0 +1,149 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "sync" + + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/util/moremaps" + "golang.org/x/tools/internal/mcp" +) + +// EventType differentiates between new and exiting sessions. +type EventType int + +const ( + SessionNew EventType = iota + SessionExiting +) + +// SessionEvent holds information about the session event. +type SessionEvent struct { + Type EventType + Session *cache.Session +} + +// Serve start a MCP server serving at the input address. +func Serve(ctx context.Context, address string, eventChan chan SessionEvent, cache *cache.Cache, isDaemon bool) error { + m := manager{ + mcpHandlers: make(map[string]*mcp.SSEHandler), + eventChan: eventChan, + cache: cache, + isDaemon: isDaemon, + } + return m.serve(ctx, address) +} + +// manager manages the mapping between LSP sessions and MCP servers. +type manager struct { + mu sync.Mutex // lock for mcpHandlers. + mcpHandlers map[string]*mcp.SSEHandler // map from lsp session ids to MCP sse handlers. 
+ + eventChan chan SessionEvent // channel for receiving session creation and termination event + isDaemon bool + cache *cache.Cache // TODO(hxjiang): use cache to perform static analysis +} + +// serve serves MCP server at the input address. +func (m *manager) serve(ctx context.Context, address string) error { + // Spin up go routine listen to the session event channel until channel close. + go func() { + for event := range m.eventChan { + m.mu.Lock() + switch event.Type { + case SessionNew: + m.mcpHandlers[event.Session.ID()] = mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { + return newServer(m.cache, event.Session) + }) + case SessionExiting: + delete(m.mcpHandlers, event.Session.ID()) + } + m.mu.Unlock() + } + }() + + // In daemon mode, gopls serves mcp server at ADDRESS/sessions/$SESSIONID. + // Otherwise, gopls serves mcp server at ADDRESS. + mux := http.NewServeMux() + if m.isDaemon { + mux.HandleFunc("/sessions/{id}", func(w http.ResponseWriter, r *http.Request) { + sessionID := r.PathValue("id") + + m.mu.Lock() + handler := m.mcpHandlers[sessionID] + m.mu.Unlock() + + if handler == nil { + http.Error(w, fmt.Sprintf("session %s not established", sessionID), http.StatusNotFound) + return + } + + handler.ServeHTTP(w, r) + }) + } else { + // TODO(hxjiang): should gopls serve only at a specific path? + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + m.mu.Lock() + // When not in daemon mode, gopls has at most one LSP session. + _, handler, ok := moremaps.Arbitrary(m.mcpHandlers) + m.mu.Unlock() + + if !ok { + http.Error(w, "session not established", http.StatusNotFound) + return + } + + handler.ServeHTTP(w, r) + }) + } + + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + defer listener.Close() + // TODO(hxjiang): expose the mcp server address to the lsp client. 
+ if m.isDaemon { + log.Printf("Gopls MCP daemon: listening on address %s...", listener.Addr()) + } + defer log.Printf("Gopls MCP server: exiting") + + svr := http.Server{ + Handler: mux, + BaseContext: func(net.Listener) context.Context { return ctx }, + } + // Run the server until cancellation. + go func() { + <-ctx.Done() + svr.Close() + }() + return svr.Serve(listener) +} + +func newServer(_ *cache.Cache, session *cache.Session) *mcp.Server { + s := mcp.NewServer("golang", "v0.1", nil) + + // TODO(hxjiang): replace dummy tool with tools which use cache and session. + s.AddTools(mcp.NewTool("hello_world", "Say hello to someone", helloHandler(session))) + return s +} + +type HelloParams struct { + Name string `json:"name" mcp:"the name to say hi to"` +} + +func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { + return func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { + return []mcp.Content{ + mcp.TextContent{Text: "Hi " + request.Name + ", this is lsp session " + session.ID()}, + }, nil + } +} diff --git a/gopls/internal/mcp/mcp_test.go b/gopls/internal/mcp/mcp_test.go new file mode 100644 index 00000000000..95288f71aee --- /dev/null +++ b/gopls/internal/mcp/mcp_test.go @@ -0,0 +1,36 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mcp_test + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "golang.org/x/tools/gopls/internal/mcp" +) + +func TestContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + res := make(chan error) + go func() { + res <- mcp.Serve(ctx, "localhost:0", nil, nil, true) + }() + + time.Sleep(1 * time.Second) + cancel() + + select { + case err := <-res: + if !errors.Is(err, http.ErrServerClosed) { + t.Errorf("mcp server unexpected return got %v, want: %v", err, http.ErrServerClosed) + } + case <-time.After(5 * time.Second): + t.Errorf("mcp server did not terminate after 5 seconds of context cancellation") + } +} diff --git a/gopls/internal/test/integration/bench/stress_test.go b/gopls/internal/test/integration/bench/stress_test.go index 1b63e3aff9e..3021ad88603 100644 --- a/gopls/internal/test/integration/bench/stress_test.go +++ b/gopls/internal/test/integration/bench/stress_test.go @@ -43,11 +43,10 @@ func TestPilosaStress(t *testing.T) { if err != nil { t.Fatal(err) } - - server := lsprpc.NewStreamServer(cache.New(nil), false, nil) + server := lsprpc.NewStreamServer(cache.New(nil), false, nil, nil) ts := servertest.NewPipeServer(server, jsonrpc2.NewRawStream) - ctx := context.Background() + ctx := context.Background() editor, err := fake.NewEditor(sandbox, fake.EditorConfig{}).Connect(ctx, ts, fake.ClientHooks{}) if err != nil { t.Fatal(err) diff --git a/gopls/internal/test/integration/runner.go b/gopls/internal/test/integration/runner.go index 8fdcc26af59..96427461580 100644 --- a/gopls/internal/test/integration/runner.go +++ b/gopls/internal/test/integration/runner.go @@ -341,7 +341,7 @@ func (s *loggingFramer) printBuffers(testname string, w io.Writer) { // defaultServer handles the Default execution mode. 
func (r *Runner) defaultServer() jsonrpc2.StreamServer { - return lsprpc.NewStreamServer(cache.New(r.store), false, nil) + return lsprpc.NewStreamServer(cache.New(r.store), false, nil, nil) } // forwardedServer handles the Forwarded execution mode. @@ -349,7 +349,7 @@ func (r *Runner) forwardedServer() jsonrpc2.StreamServer { r.tsOnce.Do(func() { ctx := context.Background() ctx = debug.WithInstance(ctx) - ss := lsprpc.NewStreamServer(cache.New(nil), false, nil) + ss := lsprpc.NewStreamServer(cache.New(nil), false, nil, nil) r.ts = servertest.NewTCPServer(ctx, ss, nil) }) return newForwarder("tcp", r.ts.Addr) diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index dcc005ba391..d2d4f899d48 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -968,7 +968,7 @@ func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byt ctx = debug.WithInstance(ctx) awaiter := integration.NewAwaiter(sandbox.Workdir) - ss := lsprpc.NewStreamServer(cache, false, nil) + ss := lsprpc.NewStreamServer(cache, false, nil, nil) server := servertest.NewPipeServer(ss, jsonrpc2.NewRawStream) editor, err := fake.NewEditor(sandbox, config).Connect(ctx, server, awaiter.Hooks()) if err != nil { diff --git a/gopls/internal/util/moremaps/maps.go b/gopls/internal/util/moremaps/maps.go index e25627d67b5..f85f20a9747 100644 --- a/gopls/internal/util/moremaps/maps.go +++ b/gopls/internal/util/moremaps/maps.go @@ -11,6 +11,15 @@ import ( "slices" ) +// Arbitrary returns an arbitrary (key, value) entry from the map and ok is true, if +// the map is not empty. Otherwise, it returns zero values for K and V, and false. +func Arbitrary[K comparable, V any](m map[K]V) (_ K, _ V, ok bool) { + for k, v := range m { + return k, v, true + } + return +} + // Group returns a new non-nil map containing the elements of s grouped by the // keys returned from the key func. 
func Group[K comparable, V any](s []V, key func(V) K) map[K][]V { From 8ac1955e840d0146bcdd2906bd8380c518415bde Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 10 May 2025 09:03:39 -0400 Subject: [PATCH 051/196] internal/mcp: implement resources Implementation of resources. This follows the design sketched in design/design.md, except for Server.FileResourceHandler, which will be in a forthcoming CL. As the spec recommends, the server verifies an incoming URI from the client against its list of resources, to make sure the URI came from one of them. Some things that aren't in the design doc: - I provide an exported function that creates a "resource not found" jsonrpc2.WireError error for server authors to return. A wrinkle: the MCP spec defines the error code for "resource not found" as -32002 [1]. That code is reserved for use by implementations of JSON-RPC, not applications. That matters to us because our JSON-RPC implementation already uses that code for the "server closing" error. I filed a bug against MCP about it [2] and chose a different code for the "resource not found" error. - As a slight convenience, I fill in some fields in the ReadResourceResult if the handler doesn't set them.
[1] https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling [2] https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509 Change-Id: I06153da982ea00748a3691ded923448a880cea2a Reviewed-on: https://go-review.googlesource.com/c/tools/+/671363 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 10 ++ internal/mcp/mcp_test.go | 253 +++++++++++++++++++----------- internal/mcp/protocol/generate.go | 17 +- internal/mcp/protocol/protocol.go | 56 +++++++ internal/mcp/server.go | 127 +++++++++++++-- internal/mcp/transport.go | 4 +- 6 files changed, 366 insertions(+), 101 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 6592cc8eb6d..3a32a3c7355 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -221,6 +221,16 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) return standardCall[protocol.CallToolResult](ctx, c.conn, "tools/call", params) } +// ListResources lists the resources that are currently available on the server. +func (c *Client) ListResources(ctx context.Context, params *protocol.ListResourcesParams) (*protocol.ListResourcesResult, error) { + return standardCall[protocol.ListResourcesResult](ctx, c.conn, "resources/list", params) +} + +// ReadResource ask the server to read a resource and return its contents. 
+func (c *Client) ReadResource(ctx context.Context, params *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { + return standardCall[protocol.ReadResourceResult](ctx, c.conn, "resources/read", params) +} + func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { var result TRes if err := call(ctx, conn, method, params, &result); err != nil { diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index aea6081473e..a4defd6f4d3 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" "golang.org/x/tools/internal/mcp/jsonschema" "golang.org/x/tools/internal/mcp/protocol" ) @@ -92,103 +93,179 @@ func TestEndToEnd(t *testing.T) { if err := c.Ping(ctx); err != nil { t.Fatalf("ping failed: %v", err) } + t.Run("prompts", func(t *testing.T) { + gotPrompts, err := c.ListPrompts(ctx) + if err != nil { + t.Errorf("prompts/list failed: %v", err) + } + wantPrompts := []protocol.Prompt{ + { + Name: "code_review", + Description: "do a code review", + Arguments: []protocol.PromptArgument{{Name: "Code", Required: true}}, + }, + {Name: "fail"}, + } + if diff := cmp.Diff(wantPrompts, gotPrompts); diff != "" { + t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) + } - gotPrompts, err := c.ListPrompts(ctx) - if err != nil { - t.Errorf("prompts/list failed: %v", err) - } - wantPrompts := []protocol.Prompt{ - { - Name: "code_review", - Description: "do a code review", - Arguments: []protocol.PromptArgument{{Name: "Code", Required: true}}, - }, - {Name: "fail"}, - } - if diff := cmp.Diff(wantPrompts, gotPrompts); diff != "" { - t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) - } - - gotReview, err := c.GetPrompt(ctx, "code_review", map[string]string{"Code": "1+1"}) - if err != nil { - t.Fatal(err) - } - 
wantReview := &protocol.GetPromptResult{ - Description: "Code review prompt", - Messages: []protocol.PromptMessage{{ - Content: TextContent{Text: "Please review the following code: 1+1"}.ToWire(), - Role: "user", - }}, - } - if diff := cmp.Diff(wantReview, gotReview); diff != "" { - t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) - } + gotReview, err := c.GetPrompt(ctx, "code_review", map[string]string{"Code": "1+1"}) + if err != nil { + t.Fatal(err) + } + wantReview := &protocol.GetPromptResult{ + Description: "Code review prompt", + Messages: []protocol.PromptMessage{{ + Content: TextContent{Text: "Please review the following code: 1+1"}.ToWire(), + Role: "user", + }}, + } + if diff := cmp.Diff(wantReview, gotReview); diff != "" { + t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) + } - if _, err := c.GetPrompt(ctx, "fail", map[string]string{}); err == nil || !strings.Contains(err.Error(), failure.Error()) { - t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) - } + if _, err := c.GetPrompt(ctx, "fail", map[string]string{}); err == nil || !strings.Contains(err.Error(), failure.Error()) { + t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) + } + }) - gotTools, err := c.ListTools(ctx) - if err != nil { - t.Errorf("tools/list failed: %v", err) - } - wantTools := []protocol.Tool{ - { - Name: "fail", - Description: "just fail", - InputSchema: &jsonschema.Schema{ - Type: "object", - AdditionalProperties: falseSchema, + t.Run("tools", func(t *testing.T) { + gotTools, err := c.ListTools(ctx) + if err != nil { + t.Errorf("tools/list failed: %v", err) + } + wantTools := []protocol.Tool{ + { + Name: "fail", + Description: "just fail", + InputSchema: &jsonschema.Schema{ + Type: "object", + AdditionalProperties: falseSchema, + }, }, - }, - { - Name: "greet", - Description: "say hi", - InputSchema: &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, 
- Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string"}, + { + Name: "greet", + Description: "say hi", + InputSchema: &jsonschema.Schema{ + Type: "object", + Required: []string{"Name"}, + Properties: map[string]*jsonschema.Schema{ + "Name": {Type: "string"}, + }, + AdditionalProperties: falseSchema, }, - AdditionalProperties: falseSchema, }, - }, - } - if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) - } + } + if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) + } - gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}) - if err != nil { - t.Fatal(err) - } - wantHi := &protocol.CallToolResult{ - Content: []protocol.Content{{Type: "text", Text: "hi user"}}, - } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { - t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) - } + gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}) + if err != nil { + t.Fatal(err) + } + wantHi := &protocol.CallToolResult{ + Content: []protocol.Content{{Type: "text", Text: "hi user"}}, + } + if diff := cmp.Diff(wantHi, gotHi); diff != "" { + t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) + } - gotFail, err := c.CallTool(ctx, "fail", map[string]any{}) - // Counter-intuitively, when a tool fails, we don't expect an RPC error for - // call tool: instead, the failure is embedded in the result. 
- if err != nil { - t.Fatal(err) - } - wantFail := &protocol.CallToolResult{ - IsError: true, - Content: []protocol.Content{{Type: "text", Text: failure.Error()}}, - } - if diff := cmp.Diff(wantFail, gotFail); diff != "" { - t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) - } + gotFail, err := c.CallTool(ctx, "fail", map[string]any{}) + // Counter-intuitively, when a tool fails, we don't expect an RPC error for + // call tool: instead, the failure is embedded in the result. + if err != nil { + t.Fatal(err) + } + wantFail := &protocol.CallToolResult{ + IsError: true, + Content: []protocol.Content{{Type: "text", Text: failure.Error()}}, + } + if diff := cmp.Diff(wantFail, gotFail); diff != "" { + t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) + } + }) - rootRes, err := sc.ListRoots(ctx, &protocol.ListRootsParams{}) - gotRoots := rootRes.Roots - wantRoots := slices.Collect(c.roots.all()) - if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { - t.Errorf("roots/list mismatch (-want +got):\n%s", diff) - } + t.Run("resources", func(t *testing.T) { + resource1 := protocol.Resource{ + Name: "public", + MIMEType: "text/plain", + URI: "file:///file1.txt", + } + resource2 := protocol.Resource{ + Name: "public", // names are not unique IDs + MIMEType: "text/plain", + URI: "file:///nonexistent.txt", + } + + readHandler := func(_ context.Context, r protocol.Resource, _ *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { + if r.URI == "file:///file1.txt" { + return &protocol.ReadResourceResult{ + Contents: &protocol.ResourceContents{ + Text: "file contents", + }, + }, nil + } + return nil, ResourceNotFoundError(r.URI) + } + s.AddResources( + &ServerResource{resource1, readHandler}, + &ServerResource{resource2, readHandler}) + + lrres, err := c.ListResources(ctx, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff([]protocol.Resource{resource1, resource2}, lrres.Resources); diff != "" { + t.Errorf("resources/list 
mismatch (-want, +got):\n%s", diff) + } + + for _, tt := range []struct { + uri string + mimeType string // "": not found; "text/plain": resource; "text/template": template + }{ + {"file:///file1.txt", "text/plain"}, + {"file:///nonexistent.txt", ""}, + // TODO(jba): add resource template cases when we implement them + } { + rres, err := c.ReadResource(ctx, &protocol.ReadResourceParams{URI: tt.uri}) + if err != nil { + var werr *jsonrpc2.WireError + if errors.As(err, &werr) && werr.Code == codeResourceNotFound { + if tt.mimeType != "" { + t.Errorf("%s: not found but expected it to be", tt.uri) + } + } else { + t.Fatalf("reading %s: %v", tt.uri, err) + } + } else { + if got := rres.Contents.URI; got != tt.uri { + t.Errorf("got uri %q, want %q", got, tt.uri) + } + if got := rres.Contents.MIMEType; got != tt.mimeType { + t.Errorf("%s: got MIME type %q, want %q", tt.uri, got, tt.mimeType) + } + } + } + }) + t.Run("roots", func(t *testing.T) { + // Take the server's first ServerConnection. + var sc *ServerConnection + for sc = range s.Clients() { + break + } + + rootRes, err := sc.ListRoots(ctx, &protocol.ListRootsParams{}) + if err != nil { + t.Fatal(err) + } + gotRoots := rootRes.Roots + wantRoots := slices.Collect(c.roots.all()) + if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { + t.Errorf("roots/list mismatch (-want +got):\n%s", diff) + } + }) // Disconnect. 
c.Close() diff --git a/internal/mcp/protocol/generate.go b/internal/mcp/protocol/generate.go index ff5f2465627..aed9da9b433 100644 --- a/internal/mcp/protocol/generate.go +++ b/internal/mcp/protocol/generate.go @@ -82,6 +82,11 @@ var declarations = config{ Fields: config{"Params": {Name: "ListPromptsParams"}}, }, "ListPromptsResult": {}, + "ListResourcesRequest": { + Name: "-", + Fields: config{"Params": {Name: "ListResourcesParams"}}, + }, + "ListResourcesResult": {}, "ListRootsRequest": { Name: "-", Fields: config{"Params": {Name: "ListRootsParams"}}, @@ -97,8 +102,16 @@ var declarations = config{ "PromptArgument": {}, "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string "RequestId": {Name: "-", Substitute: "any"}, // null|number|string - "Role": {}, - "Root": {}, + "ReadResourceRequest": { + Name: "-", + Fields: config{"Params": {Name: "ReadResourceParams"}}, + }, + "ReadResourceResult": { + Fields: config{"Contents": {Substitute: "*ResourceContents"}}, + }, + "Resource": {}, + "Role": {}, + "Root": {}, "ServerCapabilities": { Fields: config{ diff --git a/internal/mcp/protocol/protocol.go b/internal/mcp/protocol/protocol.go index fbc7b256fae..35b4088de6a 100644 --- a/internal/mcp/protocol/protocol.go +++ b/internal/mcp/protocol/protocol.go @@ -156,6 +156,23 @@ type ListPromptsResult struct { Prompts []Prompt `json:"prompts"` } +type ListResourcesParams struct { + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +// The server's response to a resources/list request from the client. +type ListResourcesResult struct { + // This result property is reserved by the protocol to allow clients and servers + // to attach additional metadata to their responses. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. 
If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Resources []Resource `json:"resources"` +} + type ListRootsParams struct { Meta *struct { // If specified, the caller is requesting out-of-band progress notifications for @@ -228,6 +245,45 @@ type PromptMessage struct { Role Role `json:"role"` } +type ReadResourceParams struct { + // The URI of the resource to read. The URI can use any protocol; it is up to + // the server how to interpret it. + URI string `json:"uri"` +} + +// The server's response to a resources/read request from the client. +type ReadResourceResult struct { + // This result property is reserved by the protocol to allow clients and servers + // to attach additional metadata to their responses. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + Contents *ResourceContents `json:"contents"` +} + +// A known resource that the server is capable of reading. +type Resource struct { + // Optional annotations for the client. + Annotations *Annotations `json:"annotations,omitempty"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of available + // resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // A human-readable name for this resource. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // The size of the raw resource content, in bytes (i.e., before base64 encoding + // or any tokenization), if known. + // + // This can be used by Hosts to display file sizes and estimate context window + // usage. + Size int64 `json:"size,omitempty"` + // The URI of this resource. + URI string `json:"uri"` +} + // Present if the server offers any resources to read. 
type ResourceCapabilities struct { // Whether this server supports notifications for changes to the resource list. diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 0a75411959e..95fd2b5f561 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -9,6 +9,8 @@ import ( "encoding/json" "fmt" "iter" + "log" + "net/url" "slices" "sync" @@ -26,10 +28,11 @@ type Server struct { version string opts ServerOptions - mu sync.Mutex - prompts *featureSet[*Prompt] - tools *featureSet[*Tool] - conns []*ServerConnection + mu sync.Mutex + prompts *featureSet[*Prompt] + tools *featureSet[*Tool] + resources *featureSet[*ServerResource] + conns []*ServerConnection } // ServerOptions is used to configure behavior of the server. @@ -49,11 +52,12 @@ func NewServer(name, version string, opts *ServerOptions) *Server { opts = new(ServerOptions) } return &Server{ - name: name, - version: version, - opts: *opts, - prompts: newFeatureSet(func(p *Prompt) string { return p.Definition.Name }), - tools: newFeatureSet(func(t *Tool) string { return t.Definition.Name }), + name: name, + version: version, + opts: *opts, + prompts: newFeatureSet(func(p *Prompt) string { return p.Definition.Name }), + tools: newFeatureSet(func(t *Tool) string { return t.Definition.Name }), + resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), } } @@ -99,6 +103,62 @@ func (s *Server) RemoveTools(names ...string) { } } +// ResourceNotFoundError returns an error indicating that a resource being read could +// not be found. +func ResourceNotFoundError(uri string) error { + return &jsonrpc2.WireError{ + Code: codeResourceNotFound, + Message: "Resource not found", + Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), + } +} + +// The error code to return when a resource isn't found. 
+// See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling +// However, the code they chose in in the wrong space +// (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). +// so we pick a different one, arbirarily for now (until they fix it). +// The immediate problem is that jsonprc2 defines -32002 as "server closing". +const codeResourceNotFound = -31002 + +// A ReadResourceHandler is a function that reads a resource. +// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. +type ReadResourceHandler func(context.Context, protocol.Resource, *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) + +// A ServerResource associates a Resource with its handler. +type ServerResource struct { + Resource protocol.Resource + Handler ReadResourceHandler +} + +// AddResource adds the given resource to the server and associates it with +// a [ReadResourceHandler], which will be called when the client calls [ClientSession.ReadResource]. +// If a resource with the same URI already exists, this one replaces it. +// AddResource panics if a resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResources(resources ...*ServerResource) { + s.mu.Lock() + defer s.mu.Unlock() + for _, r := range resources { + u, err := url.Parse(r.Resource.URI) + if err != nil { + panic(err) // url.Parse includes the URI in the error + } + if !u.IsAbs() { + panic(fmt.Errorf("URI %s needs a scheme", r.Resource.URI)) + } + s.resources.add(r) + } + // TODO: notify +} + +// RemoveResources removes the resources with the given URIs. +// It is not an error to remove a nonexistent resource. +func (s *Server) RemoveResources(uris ...string) { + s.mu.Lock() + defer s.mu.Unlock() + s.resources.remove(uris...) +} + // Clients returns an iterator that yields the current set of client // connections. 
func (s *Server) Clients() iter.Seq[*ServerConnection] { @@ -149,6 +209,47 @@ func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *pro return tool.Handler(ctx, cc, params.Arguments) } +func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *protocol.ListResourcesParams) (*protocol.ListResourcesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + res := new(protocol.ListResourcesResult) + for r := range s.resources.all() { + res.Resources = append(res.Resources, r.Resource) + } + return res, nil +} + +func (s *Server) readResource(ctx context.Context, _ *ServerConnection, params *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { + log.Printf("readResource") + defer log.Printf("done") + uri := params.URI + // Look up the resource URI in the list we have. + // This is a security check as well as an information lookup. + s.mu.Lock() + resource, ok := s.resources.get(uri) + s.mu.Unlock() + if !ok { + // Don't expose the server configuration to the client. + // Treat an unregistered resource the same as a registered one that couldn't be found. + return nil, ResourceNotFoundError(uri) + } + res, err := resource.Handler(ctx, resource.Resource, params) + if err != nil { + return nil, err + } + if res == nil || res.Contents == nil { + return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) + } + // As a convenience, populate some fields. + if res.Contents.URI == "" { + res.Contents.URI = uri + } + if res.Contents.MIMEType == "" { + res.Contents.MIMEType = resource.Resource.MIMEType + } + return res, nil +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection. 
@@ -226,7 +327,7 @@ func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( case "initialize", "ping": default: if !initialized { - return nil, fmt.Errorf("method %q is invalid during session ininitialization", req.Method) + return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } @@ -254,6 +355,12 @@ func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( case "tools/call": return dispatch(ctx, cc, req, cc.server.callTool) + case "resources/list": + return dispatch(ctx, cc, req, cc.server.listResources) + + case "resources/read": + return dispatch(ctx, cc, req, cc.server.readResource) + case "notifications/initialized": } return nil, jsonrpc2.ErrNotHandled diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 403d3a2371c..faa2a806586 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -147,6 +147,8 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result // call executes and awaits a jsonrpc2 call on the given connection, // translating errors into the mcp domain. func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + // TODO: the "%w"s in this function effectively make jsonrpc2.WireError part of the API. + // Consider alternatives. call := conn.Call(ctx, method, params) err := call.Await(ctx, result) switch { @@ -160,7 +162,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, }) return errors.Join(ctx.Err(), err) case err != nil: - return fmt.Errorf("calling %q: %v", method, err) + return fmt.Errorf("calling %q: %w", method, err) } return nil } From 0987b8950f448705e510e01e0f0735aa1abdafb1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 18:42:33 -0400 Subject: [PATCH 052/196] jsonschema: resolve remote references Support remote references. 
Most of the machinery was in place; we just had to define the right loader, copy files from the test repo, and download the meta-schema files from json-schema.org. Change-Id: Ib588200ef0fa483641b9b53f9904f53b39dc4b82 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670795 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- .../draft2020-12/meta/applicator.json | 45 +++ .../draft2020-12/meta/content.json | 14 + .../meta-schemas/draft2020-12/meta/core.json | 48 +++ .../draft2020-12/meta/format-annotation.json | 11 + .../draft2020-12/meta/meta-data.json | 34 ++ .../draft2020-12/meta/unevaluated.json | 12 + .../draft2020-12/meta/validation.json | 95 +++++ .../meta-schemas/draft2020-12/schema.json | 58 +++ .../testdata/draft2020-12/refRemote.json | 342 ++++++++++++++++++ .../mcp/jsonschema/testdata/remotes/README.md | 4 + .../remotes/different-id-ref-string.json | 5 + .../baseUriChange/folderInteger.json | 4 + .../baseUriChangeFolder/folderInteger.json | 4 + .../folderInteger.json | 4 + .../draft2020-12/detached-dynamicref.json | 13 + .../remotes/draft2020-12/detached-ref.json | 13 + .../draft2020-12/extendible-dynamic-ref.json | 21 ++ .../draft2020-12/format-assertion-false.json | 13 + .../draft2020-12/format-assertion-true.json | 13 + .../remotes/draft2020-12/integer.json | 4 + .../locationIndependentIdentifier.json | 12 + .../metaschema-no-validation.json | 13 + .../metaschema-optional-vocabulary.json | 14 + .../remotes/draft2020-12/name-defs.json | 16 + .../draft2020-12/nested/foo-ref-string.json | 7 + .../remotes/draft2020-12/nested/string.json | 4 + .../remotes/draft2020-12/prefixItems.json | 7 + .../remotes/draft2020-12/ref-and-defs.json | 12 + .../remotes/draft2020-12/subSchemas.json | 11 + .../testdata/remotes/draft2020-12/tree.json | 17 + .../nested-absolute-ref-to-string.json | 9 + .../testdata/remotes/urn-ref-string.json | 5 + internal/mcp/jsonschema/validate_test.go | 33 +- 33 files changed, 913 insertions(+), 4 deletions(-) create mode 
100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/applicator.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/content.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/core.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/validation.json create mode 100644 internal/mcp/jsonschema/meta-schemas/draft2020-12/schema.json create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/refRemote.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/README.md create mode 100644 internal/mcp/jsonschema/testdata/remotes/different-id-ref-string.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-ref.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/integer.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json create mode 100644 
internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/name-defs.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/string.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/prefixItems.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/subSchemas.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/draft2020-12/tree.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json create mode 100644 internal/mcp/jsonschema/testdata/remotes/urn-ref-string.json diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/applicator.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/applicator.json new file mode 100644 index 00000000000..f4775974a92 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/applicator.json @@ -0,0 +1,45 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/applicator", + "$dynamicAnchor": "meta", + + "title": "Applicator vocabulary meta-schema", + "type": ["object", "boolean"], + "properties": { + "prefixItems": { "$ref": "#/$defs/schemaArray" }, + "items": { "$dynamicRef": "#meta" }, + "contains": { "$dynamicRef": "#meta" }, + "additionalProperties": { "$dynamicRef": "#meta" }, + "properties": { + "type": "object", + "additionalProperties": { "$dynamicRef": "#meta" }, + "default": {} + }, + "patternProperties": { + "type": "object", + "additionalProperties": { "$dynamicRef": "#meta" }, + "propertyNames": { 
"format": "regex" }, + "default": {} + }, + "dependentSchemas": { + "type": "object", + "additionalProperties": { "$dynamicRef": "#meta" }, + "default": {} + }, + "propertyNames": { "$dynamicRef": "#meta" }, + "if": { "$dynamicRef": "#meta" }, + "then": { "$dynamicRef": "#meta" }, + "else": { "$dynamicRef": "#meta" }, + "allOf": { "$ref": "#/$defs/schemaArray" }, + "anyOf": { "$ref": "#/$defs/schemaArray" }, + "oneOf": { "$ref": "#/$defs/schemaArray" }, + "not": { "$dynamicRef": "#meta" } + }, + "$defs": { + "schemaArray": { + "type": "array", + "minItems": 1, + "items": { "$dynamicRef": "#meta" } + } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/content.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/content.json new file mode 100644 index 00000000000..76e3760d269 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/content.json @@ -0,0 +1,14 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/content", + "$dynamicAnchor": "meta", + + "title": "Content vocabulary meta-schema", + + "type": ["object", "boolean"], + "properties": { + "contentEncoding": { "type": "string" }, + "contentMediaType": { "type": "string" }, + "contentSchema": { "$dynamicRef": "#meta" } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/core.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/core.json new file mode 100644 index 00000000000..69186228948 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/core.json @@ -0,0 +1,48 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/core", + "$dynamicAnchor": "meta", + + "title": "Core vocabulary meta-schema", + "type": ["object", "boolean"], + "properties": { + "$id": { + "$ref": "#/$defs/uriReferenceString", + "$comment": "Non-empty fragments not allowed.", + "pattern": "^[^#]*#?$" + }, + 
"$schema": { "$ref": "#/$defs/uriString" }, + "$ref": { "$ref": "#/$defs/uriReferenceString" }, + "$anchor": { "$ref": "#/$defs/anchorString" }, + "$dynamicRef": { "$ref": "#/$defs/uriReferenceString" }, + "$dynamicAnchor": { "$ref": "#/$defs/anchorString" }, + "$vocabulary": { + "type": "object", + "propertyNames": { "$ref": "#/$defs/uriString" }, + "additionalProperties": { + "type": "boolean" + } + }, + "$comment": { + "type": "string" + }, + "$defs": { + "type": "object", + "additionalProperties": { "$dynamicRef": "#meta" } + } + }, + "$defs": { + "anchorString": { + "type": "string", + "pattern": "^[A-Za-z_][-A-Za-z0-9._]*$" + }, + "uriString": { + "type": "string", + "format": "uri" + }, + "uriReferenceString": { + "type": "string", + "format": "uri-reference" + } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json new file mode 100644 index 00000000000..3479e6695ed --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json @@ -0,0 +1,11 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/format-annotation", + "$dynamicAnchor": "meta", + + "title": "Format vocabulary meta-schema for annotation results", + "type": ["object", "boolean"], + "properties": { + "format": { "type": "string" } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json new file mode 100644 index 00000000000..4049ab21b11 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json @@ -0,0 +1,34 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/meta-data", + "$dynamicAnchor": "meta", + + "title": "Meta-data vocabulary meta-schema", + + "type": ["object", 
"boolean"], + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "default": true, + "deprecated": { + "type": "boolean", + "default": false + }, + "readOnly": { + "type": "boolean", + "default": false + }, + "writeOnly": { + "type": "boolean", + "default": false + }, + "examples": { + "type": "array", + "items": true + } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json new file mode 100644 index 00000000000..93779e54ed3 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/unevaluated", + "$dynamicAnchor": "meta", + + "title": "Unevaluated applicator vocabulary meta-schema", + "type": ["object", "boolean"], + "properties": { + "unevaluatedItems": { "$dynamicRef": "#meta" }, + "unevaluatedProperties": { "$dynamicRef": "#meta" } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/validation.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/validation.json new file mode 100644 index 00000000000..ebb75db77a7 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/meta/validation.json @@ -0,0 +1,95 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/meta/validation", + "$dynamicAnchor": "meta", + + "title": "Validation vocabulary meta-schema", + "type": ["object", "boolean"], + "properties": { + "type": { + "anyOf": [ + { "$ref": "#/$defs/simpleTypes" }, + { + "type": "array", + "items": { "$ref": "#/$defs/simpleTypes" }, + "minItems": 1, + "uniqueItems": true + } + ] + }, + "const": true, + "enum": { + "type": "array", + "items": true + }, + "multipleOf": { + "type": "number", + "exclusiveMinimum": 0 + }, + "maximum": { + "type": 
"number" + }, + "exclusiveMaximum": { + "type": "number" + }, + "minimum": { + "type": "number" + }, + "exclusiveMinimum": { + "type": "number" + }, + "maxLength": { "$ref": "#/$defs/nonNegativeInteger" }, + "minLength": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, + "pattern": { + "type": "string", + "format": "regex" + }, + "maxItems": { "$ref": "#/$defs/nonNegativeInteger" }, + "minItems": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, + "uniqueItems": { + "type": "boolean", + "default": false + }, + "maxContains": { "$ref": "#/$defs/nonNegativeInteger" }, + "minContains": { + "$ref": "#/$defs/nonNegativeInteger", + "default": 1 + }, + "maxProperties": { "$ref": "#/$defs/nonNegativeInteger" }, + "minProperties": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, + "required": { "$ref": "#/$defs/stringArray" }, + "dependentRequired": { + "type": "object", + "additionalProperties": { + "$ref": "#/$defs/stringArray" + } + } + }, + "$defs": { + "nonNegativeInteger": { + "type": "integer", + "minimum": 0 + }, + "nonNegativeIntegerDefault0": { + "$ref": "#/$defs/nonNegativeInteger", + "default": 0 + }, + "simpleTypes": { + "enum": [ + "array", + "boolean", + "integer", + "null", + "number", + "object", + "string" + ] + }, + "stringArray": { + "type": "array", + "items": { "type": "string" }, + "uniqueItems": true, + "default": [] + } + } +} diff --git a/internal/mcp/jsonschema/meta-schemas/draft2020-12/schema.json b/internal/mcp/jsonschema/meta-schemas/draft2020-12/schema.json new file mode 100644 index 00000000000..d5e2d31c3c8 --- /dev/null +++ b/internal/mcp/jsonschema/meta-schemas/draft2020-12/schema.json @@ -0,0 +1,58 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://json-schema.org/draft/2020-12/schema", + "$vocabulary": { + "https://json-schema.org/draft/2020-12/vocab/core": true, + "https://json-schema.org/draft/2020-12/vocab/applicator": true, + "https://json-schema.org/draft/2020-12/vocab/unevaluated": true, + 
"https://json-schema.org/draft/2020-12/vocab/validation": true, + "https://json-schema.org/draft/2020-12/vocab/meta-data": true, + "https://json-schema.org/draft/2020-12/vocab/format-annotation": true, + "https://json-schema.org/draft/2020-12/vocab/content": true + }, + "$dynamicAnchor": "meta", + + "title": "Core and Validation specifications meta-schema", + "allOf": [ + {"$ref": "meta/core"}, + {"$ref": "meta/applicator"}, + {"$ref": "meta/unevaluated"}, + {"$ref": "meta/validation"}, + {"$ref": "meta/meta-data"}, + {"$ref": "meta/format-annotation"}, + {"$ref": "meta/content"} + ], + "type": ["object", "boolean"], + "$comment": "This meta-schema also defines keywords that have appeared in previous drafts in order to prevent incompatible extensions as they remain in common use.", + "properties": { + "definitions": { + "$comment": "\"definitions\" has been replaced by \"$defs\".", + "type": "object", + "additionalProperties": { "$dynamicRef": "#meta" }, + "deprecated": true, + "default": {} + }, + "dependencies": { + "$comment": "\"dependencies\" has been split and replaced by \"dependentSchemas\" and \"dependentRequired\" in order to serve their differing semantics.", + "type": "object", + "additionalProperties": { + "anyOf": [ + { "$dynamicRef": "#meta" }, + { "$ref": "meta/validation#/$defs/stringArray" } + ] + }, + "deprecated": true, + "default": {} + }, + "$recursiveAnchor": { + "$comment": "\"$recursiveAnchor\" has been replaced by \"$dynamicAnchor\".", + "$ref": "meta/core#/$defs/anchorString", + "deprecated": true + }, + "$recursiveRef": { + "$comment": "\"$recursiveRef\" has been replaced by \"$dynamicRef\".", + "$ref": "meta/core#/$defs/uriReferenceString", + "deprecated": true + } + } +} diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/refRemote.json b/internal/mcp/jsonschema/testdata/draft2020-12/refRemote.json new file mode 100644 index 00000000000..047ac74ca0c --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/refRemote.json 
@@ -0,0 +1,342 @@ +[ + { + "description": "remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/integer.json" + }, + "tests": [ + { + "description": "remote ref valid", + "data": 1, + "valid": true + }, + { + "description": "remote ref invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "fragment within remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/integer" + }, + "tests": [ + { + "description": "remote fragment valid", + "data": 1, + "valid": true + }, + { + "description": "remote fragment invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "anchor within remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#foo" + }, + "tests": [ + { + "description": "remote anchor valid", + "data": 1, + "valid": true + }, + { + "description": "remote anchor invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "ref within remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/refToInteger" + }, + "tests": [ + { + "description": "ref within ref valid", + "data": 1, + "valid": true + }, + { + "description": "ref within ref invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "base URI change", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/", + "items": { + "$id": "baseUriChange/", + "items": {"$ref": "folderInteger.json"} + } + }, + "tests": [ + { + "description": "base URI change ref valid", + "data": [[1]], + "valid": true + }, + { + "description": "base URI change ref invalid", + "data": [["a"]], + "valid": false 
+ } + ] + }, + { + "description": "base URI change - change folder", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/scope_change_defs1.json", + "type" : "object", + "properties": {"list": {"$ref": "baseUriChangeFolder/"}}, + "$defs": { + "baz": { + "$id": "baseUriChangeFolder/", + "type": "array", + "items": {"$ref": "folderInteger.json"} + } + } + }, + "tests": [ + { + "description": "number is valid", + "data": {"list": [1]}, + "valid": true + }, + { + "description": "string is invalid", + "data": {"list": ["a"]}, + "valid": false + } + ] + }, + { + "description": "base URI change - change folder in subschema", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/scope_change_defs2.json", + "type" : "object", + "properties": {"list": {"$ref": "baseUriChangeFolderInSubschema/#/$defs/bar"}}, + "$defs": { + "baz": { + "$id": "baseUriChangeFolderInSubschema/", + "$defs": { + "bar": { + "type": "array", + "items": {"$ref": "folderInteger.json"} + } + } + } + } + }, + "tests": [ + { + "description": "number is valid", + "data": {"list": [1]}, + "valid": true + }, + { + "description": "string is invalid", + "data": {"list": ["a"]}, + "valid": false + } + ] + }, + { + "description": "root ref in remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/object", + "type": "object", + "properties": { + "name": {"$ref": "name-defs.json#/$defs/orNull"} + } + }, + "tests": [ + { + "description": "string is valid", + "data": { + "name": "foo" + }, + "valid": true + }, + { + "description": "null is valid", + "data": { + "name": null + }, + "valid": true + }, + { + "description": "object is invalid", + "data": { + "name": { + "name": null + } + }, + "valid": false + } + ] + }, + { + "description": "remote ref with ref to defs", + "schema": { + "$schema": 
"https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/schema-remote-ref-ref-defs1.json", + "$ref": "ref-and-defs.json" + }, + "tests": [ + { + "description": "invalid", + "data": { + "bar": 1 + }, + "valid": false + }, + { + "description": "valid", + "data": { + "bar": "a" + }, + "valid": true + } + ] + }, + { + "description": "Location-independent identifier in remote ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#/$defs/refToInteger" + }, + "tests": [ + { + "description": "integer is valid", + "data": 1, + "valid": true + }, + { + "description": "string is invalid", + "data": "foo", + "valid": false + } + ] + }, + { + "description": "retrieved nested refs resolve relative to their URI not $id", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/some-id", + "properties": { + "name": {"$ref": "nested/foo-ref-string.json"} + } + }, + "tests": [ + { + "description": "number is invalid", + "data": { + "name": {"foo": 1} + }, + "valid": false + }, + { + "description": "string is valid", + "data": { + "name": {"foo": "a"} + }, + "valid": true + } + ] + }, + { + "description": "remote HTTP ref with different $id", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/different-id-ref-string.json" + }, + "tests": [ + { + "description": "number is invalid", + "data": 1, + "valid": false + }, + { + "description": "string is valid", + "data": "foo", + "valid": true + } + ] + }, + { + "description": "remote HTTP ref with different URN $id", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/urn-ref-string.json" + }, + "tests": [ + { + "description": "number is invalid", + "data": 1, + "valid": false + }, + { + "description": "string is valid", + 
"data": "foo", + "valid": true + } + ] + }, + { + "description": "remote HTTP ref with nested absolute ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/nested-absolute-ref-to-string.json" + }, + "tests": [ + { + "description": "number is invalid", + "data": 1, + "valid": false + }, + { + "description": "string is valid", + "data": "foo", + "valid": true + } + ] + }, + { + "description": "$ref to $ref finds detached $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/detached-ref.json#/$defs/foo" + }, + "tests": [ + { + "description": "number is valid", + "data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + } +] diff --git a/internal/mcp/jsonschema/testdata/remotes/README.md b/internal/mcp/jsonschema/testdata/remotes/README.md new file mode 100644 index 00000000000..8a641dbd348 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/README.md @@ -0,0 +1,4 @@ +# JSON Schema test suite: remote references + +These files were copied from +https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/remotes. 
diff --git a/internal/mcp/jsonschema/testdata/remotes/different-id-ref-string.json b/internal/mcp/jsonschema/testdata/remotes/different-id-ref-string.json new file mode 100644 index 00000000000..7f888609398 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/different-id-ref-string.json @@ -0,0 +1,5 @@ +{ + "$id": "http://localhost:1234/real-id-ref-string.json", + "$defs": {"bar": {"type": "string"}}, + "$ref": "#/$defs/bar" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json new file mode 100644 index 00000000000..1f44a631321 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "integer" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json new file mode 100644 index 00000000000..1f44a631321 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "integer" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json new file mode 100644 index 00000000000..1f44a631321 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "integer" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json 
b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json new file mode 100644 index 00000000000..07cce1dac47 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json @@ -0,0 +1,13 @@ +{ + "$id": "http://localhost:1234/draft2020-12/detached-dynamicref.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "foo": { + "$dynamicRef": "#detached" + }, + "detached": { + "$dynamicAnchor": "detached", + "type": "integer" + } + } +} \ No newline at end of file diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-ref.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-ref.json new file mode 100644 index 00000000000..9c2dca93ca4 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/detached-ref.json @@ -0,0 +1,13 @@ +{ + "$id": "http://localhost:1234/draft2020-12/detached-ref.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "foo": { + "$ref": "#detached" + }, + "detached": { + "$anchor": "detached", + "type": "integer" + } + } +} \ No newline at end of file diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json new file mode 100644 index 00000000000..65bc0c217d3 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json @@ -0,0 +1,21 @@ +{ + "description": "extendible array", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/extendible-dynamic-ref.json", + "type": "object", + "properties": { + "elements": { + "type": "array", + "items": { + "$dynamicRef": "#elements" + } + } + }, + "required": ["elements"], + "additionalProperties": false, + "$defs": { + "elements": { + "$dynamicAnchor": "elements" + } + } +} diff --git 
a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json new file mode 100644 index 00000000000..43a711c9d20 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json @@ -0,0 +1,13 @@ +{ + "$id": "http://localhost:1234/draft2020-12/format-assertion-false.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$vocabulary": { + "https://json-schema.org/draft/2020-12/vocab/core": true, + "https://json-schema.org/draft/2020-12/vocab/format-assertion": false + }, + "$dynamicAnchor": "meta", + "allOf": [ + { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, + { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } + ] +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json new file mode 100644 index 00000000000..39c6b0abf5b --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json @@ -0,0 +1,13 @@ +{ + "$id": "http://localhost:1234/draft2020-12/format-assertion-true.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$vocabulary": { + "https://json-schema.org/draft/2020-12/vocab/core": true, + "https://json-schema.org/draft/2020-12/vocab/format-assertion": true + }, + "$dynamicAnchor": "meta", + "allOf": [ + { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, + { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } + ] +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/integer.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/integer.json new file mode 100644 index 00000000000..1f44a631321 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/integer.json @@ -0,0 +1,4 @@ +{ + "$schema": 
"https://json-schema.org/draft/2020-12/schema", + "type": "integer" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json new file mode 100644 index 00000000000..6565a1ee000 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "refToInteger": { + "$ref": "#foo" + }, + "A": { + "$anchor": "foo", + "type": "integer" + } + } +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json new file mode 100644 index 00000000000..71be8b5da08 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json @@ -0,0 +1,13 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/metaschema-no-validation.json", + "$vocabulary": { + "https://json-schema.org/draft/2020-12/vocab/applicator": true, + "https://json-schema.org/draft/2020-12/vocab/core": true + }, + "$dynamicAnchor": "meta", + "allOf": [ + { "$ref": "https://json-schema.org/draft/2020-12/meta/applicator" }, + { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } + ] +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json new file mode 100644 index 00000000000..a6963e54806 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json @@ -0,0 +1,14 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/metaschema-optional-vocabulary.json", + "$vocabulary": { + 
"https://json-schema.org/draft/2020-12/vocab/validation": true, + "https://json-schema.org/draft/2020-12/vocab/core": true, + "http://localhost:1234/draft/2020-12/vocab/custom": false + }, + "$dynamicAnchor": "meta", + "allOf": [ + { "$ref": "https://json-schema.org/draft/2020-12/meta/validation" }, + { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } + ] +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/name-defs.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/name-defs.json new file mode 100644 index 00000000000..67bc33c5151 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/name-defs.json @@ -0,0 +1,16 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "orNull": { + "anyOf": [ + { + "type": "null" + }, + { + "$ref": "#" + } + ] + } + }, + "type": "string" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json new file mode 100644 index 00000000000..29661ff9fb1 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json @@ -0,0 +1,7 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "foo": {"$ref": "string.json"} + } +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/string.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/string.json new file mode 100644 index 00000000000..6607ac53454 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/nested/string.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "string" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/prefixItems.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/prefixItems.json new file mode 100644 index 00000000000..acd8293c61a --- /dev/null +++ 
b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/prefixItems.json @@ -0,0 +1,7 @@ +{ + "$id": "http://localhost:1234/draft2020-12/prefixItems.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "prefixItems": [ + {"type": "string"} + ] +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json new file mode 100644 index 00000000000..16d30fa3aa3 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/ref-and-defs.json", + "$defs": { + "inner": { + "properties": { + "bar": { "type": "string" } + } + } + }, + "$ref": "#/$defs/inner" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/subSchemas.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/subSchemas.json new file mode 100644 index 00000000000..1bb4846d757 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/subSchemas.json @@ -0,0 +1,11 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "integer": { + "type": "integer" + }, + "refToInteger": { + "$ref": "#/$defs/integer" + } + } +} diff --git a/internal/mcp/jsonschema/testdata/remotes/draft2020-12/tree.json b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/tree.json new file mode 100644 index 00000000000..b07555fb333 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/draft2020-12/tree.json @@ -0,0 +1,17 @@ +{ + "description": "tree schema, extensible", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/tree.json", + "$dynamicAnchor": "node", + + "type": "object", + "properties": { + "data": true, + "children": { + "type": "array", + "items": { + "$dynamicRef": "#node" + } + } + } +} diff --git 
a/internal/mcp/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json b/internal/mcp/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json new file mode 100644 index 00000000000..f46c761643c --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json @@ -0,0 +1,9 @@ +{ + "$defs": { + "bar": { + "$id": "http://localhost:1234/the-nested-id.json", + "type": "string" + } + }, + "$ref": "http://localhost:1234/the-nested-id.json" +} diff --git a/internal/mcp/jsonschema/testdata/remotes/urn-ref-string.json b/internal/mcp/jsonschema/testdata/remotes/urn-ref-string.json new file mode 100644 index 00000000000..aca2211b7f0 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/remotes/urn-ref-string.json @@ -0,0 +1,5 @@ +{ + "$id": "urn:uuid:feebdaed-ffff-0000-ffff-0000deadbeef", + "$defs": {"bar": {"type": "string"}}, + "$ref": "#/$defs/bar" +} diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index f8fe929eb0e..bd66560ef83 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -6,6 +6,8 @@ package jsonschema import ( "encoding/json" + "fmt" + "net/url" "os" "path/filepath" "strings" @@ -52,15 +54,12 @@ func TestValidate(t *testing.T) { } for _, g := range groups { t.Run(g.Description, func(t *testing.T) { - if strings.Contains(g.Description, "remote ref") { - t.Skip("remote refs not yet supported") - } for s := range g.Schema.all() { if s.DynamicAnchor != "" || s.DynamicRef != "" { t.Skip("schema or subschema has unimplemented keywords") } } - rs, err := g.Schema.Resolve("", nil) + rs, err := g.Schema.Resolve("", loadRemote) if err != nil { t.Fatal(err) } @@ -116,3 +115,29 @@ func TestStructInstance(t *testing.T) { } } } + +// loadRemote loads a remote reference used in the test suite. +func loadRemote(uri *url.URL) (*Schema, error) { + // Anything with localhost:1234 refers to the remotes directory in the test suite repo. 
+ if uri.Host == "localhost:1234" { + return loadSchemaFromFile(filepath.FromSlash(filepath.Join("testdata/remotes", uri.Path))) + } + // One test needs the meta-schema files. + const metaPrefix = "https://json-schema.org/draft/2020-12/" + if after, ok := strings.CutPrefix(uri.String(), metaPrefix); ok { + return loadSchemaFromFile(filepath.FromSlash("meta-schemas/draft2020-12/" + after + ".json")) + } + return nil, fmt.Errorf("don't know how to load %s", uri) +} + +func loadSchemaFromFile(filename string) (*Schema, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + var s Schema + if err := json.Unmarshal(data, &s); err != nil { + return nil, fmt.Errorf("unmarshaling JSON at %s: %w", filename, err) + } + return &s, nil +} From 7d76ce6d146a00d2de90c2e67e395015f7aeffd0 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 7 May 2025 21:11:11 -0400 Subject: [PATCH 053/196] internal/mcp/jsonschema: check that schemas form a tree When we check the root schema during resolution, first make sure its descendants form a tree. 
Change-Id: Icc430e7df463aa84b961c28616be567f98960430 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670776 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- internal/mcp/jsonschema/resolve.go | 15 ++++++++++++--- internal/mcp/jsonschema/resolve_test.go | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 512ead0341d..2ba51443773 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -161,14 +161,23 @@ func (r *resolver) resolveRefs(rs *Resolved) error { return nil } -func (s *Schema) check() error { - if s == nil { +func (root *Schema) check() error { + if root == nil { return errors.New("nil schema") } var errs []error report := func(err error) { errs = append(errs, err) } - for ss := range s.all() { + seen := map[*Schema]bool{} + for ss := range root.all() { + if seen[ss] { + // The schema graph rooted at root is not a tree, but it needs to + // be because we assume a unique parent when we store a schema's base + // in the Schema. A cycle would also put Schema.all into an infinite + // recursion. + return fmt.Errorf("schemas rooted at %s do not form a tree (saw %s twice)", root, ss) + } + seen[ss] = true ss.checkLocal(report) } return errors.Join(errs...)
diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 2b469d4dd9a..67b2fe0f687 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -10,6 +10,7 @@ import ( "net/url" "regexp" "slices" + "strings" "testing" ) @@ -40,6 +41,22 @@ func TestCheckLocal(t *testing.T) { } } +func TestSchemaNonTree(t *testing.T) { + run := func(s *Schema, kind string) { + err := s.check() + if err == nil || !strings.Contains(err.Error(), "tree") { + t.Fatalf("did not detect %s", kind) + } + } + + s := &Schema{Type: "number"} + run(&Schema{Items: s, Contains: s}, "DAG") + + root := &Schema{Items: s} + s.Items = root + run(root, "cycle") +} + func TestResolveURIs(t *testing.T) { for _, baseURI := range []string{"", "http://a.com"} { t.Run(baseURI, func(t *testing.T) { From 403f7ef14562a7007a9e4e1fc996a49a491299fc Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Thu, 8 May 2025 18:59:40 +0000 Subject: [PATCH 054/196] internal/mcp/design: add pagination Design for pagination. Change-Id: I8257a499b1e4691382b5681b84a52030344ab0e3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671135 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Reviewed-by: Robert Findley --- internal/mcp/design/design.md | 37 ++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index e1d8582dbf6..aeba7f82158 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -1016,4 +1016,39 @@ type ClientOptions struct { ### Pagination - +Servers initiate pagination for `ListTools`, `ListPrompts`, `ListResources`, +and `ListResourceTemplates`, dictating the page size and providing a +`NextCursor` field in the Result if more pages exist. The SDK implements keyset +pagination, using the `unique ID` as the key for a stable sort order and encoding +the cursor as an opaque string. 
+ +For server implementations, the page size for the list operation may be +configured via the `ServerOptions.PageSize` field. PageSize must be a +non-negative integer. If zero, a sensible default is used. + +```go +type ServerOptions struct { + ... + PageSize int +} +``` + +Client requests for List methods include an optional Cursor field for +pagination. Server responses for List methods include a `NextCursor` field if +more pages exist. + +In addition to the `List` methods, the SDK provides an iterator method for each +list operation. This simplifies pagination for clients by automatically handling +the underlying pagination logic. + +For example, if we have a List method like this: + +```go +func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsResult, error) +``` + +We will also provide an iterator method like this: + +```go +func (*ClientSession) Tools(context.Context, *ListToolsParams) iter.Seq2[Tool, error] +``` From f1f12cfafeb44535738c285fec3130825ee95e72 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 13 May 2025 20:58:47 +0000 Subject: [PATCH 055/196] internal/mcp: merge the protocol package into mcp In accordance with the updated design, merge the protocol package into the mcp package. To do so, rename protocol.Content to WireContent and protocol.ResourceContents to WireResource. Also, remove some stray logging in readResource.
Change-Id: I76f61ce8d67dd3ac520ef2e1959ade0a0785a480 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672375 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 51 ++++---- internal/mcp/cmd_test.go | 5 +- internal/mcp/content.go | 79 +++++++++--- internal/mcp/content_test.go | 17 ++- internal/mcp/examples/hello/main.go | 7 +- internal/mcp/{protocol => }/generate.go | 21 +-- internal/mcp/mcp.go | 2 + internal/mcp/mcp_test.go | 47 ++++--- internal/mcp/prompt.go | 31 +++-- internal/mcp/prompt_test.go | 11 +- internal/mcp/{protocol => }/protocol.go | 162 ++++++++++++------------ internal/mcp/protocol/content.go | 56 -------- internal/mcp/protocol/doc.go | 13 -- internal/mcp/server.go | 56 ++++---- internal/mcp/sse_test.go | 5 +- internal/mcp/tool.go | 29 ++--- internal/mcp/tool_test.go | 2 +- internal/mcp/transport.go | 5 +- 18 files changed, 282 insertions(+), 317 deletions(-) rename internal/mcp/{protocol => }/generate.go (96%) rename internal/mcp/{protocol => }/protocol.go (94%) delete mode 100644 internal/mcp/protocol/content.go delete mode 100644 internal/mcp/protocol/doc.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 3a32a3c7355..f538e16fddb 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -13,7 +13,6 @@ import ( "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/protocol" ) // A Client is an MCP client, which may be connected to an MCP server @@ -25,8 +24,8 @@ type Client struct { opts ClientOptions mu sync.Mutex conn *jsonrpc2.Connection - roots *featureSet[protocol.Root] - initializeResult *protocol.InitializeResult + roots *featureSet[Root] + initializeResult *initializeResult } // NewClient creates a new Client. 
@@ -39,7 +38,7 @@ func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { name: name, version: version, transport: t, - roots: newFeatureSet(func(r protocol.Root) string { return r.URI }), + roots: newFeatureSet(func(r Root) string { return r.URI }), } if opts != nil { c.opts = *opts @@ -84,13 +83,13 @@ func (c *Client) Start(ctx context.Context) (err error) { if err != nil { return err } - params := &protocol.InitializeParams{ - ClientInfo: protocol.Implementation{Name: c.name, Version: c.version}, + params := &initializeParams{ + ClientInfo: implementation{Name: c.name, Version: c.version}, } if err := call(ctx, c.conn, "initialize", params, &c.initializeResult); err != nil { return err } - if err := c.conn.Notify(ctx, "notifications/initialized", &protocol.InitializedParams{}); err != nil { + if err := c.conn.Notify(ctx, "notifications/initialized", &initializedParams{}); err != nil { return err } return nil @@ -113,7 +112,7 @@ func (c *Client) Wait() error { // replacing any with the same URIs, // and notifies any connected servers. // TODO: notification -func (c *Client) AddRoots(roots ...protocol.Root) { +func (c *Client) AddRoots(roots ...Root) { c.mu.Lock() defer c.mu.Unlock() c.roots.add(roots...) @@ -129,10 +128,10 @@ func (c *Client) RemoveRoots(uris ...string) { c.roots.remove(uris...) } -func (c *Client) listRoots(_ context.Context, _ *protocol.ListRootsParams) (*protocol.ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, _ *ListRootsParams) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() - return &protocol.ListRootsResult{ + return &ListRootsResult{ Roots: slices.Collect(c.roots.all()), }, nil } @@ -160,10 +159,10 @@ func (c *Client) Ping(ctx context.Context) error { } // ListPrompts lists prompts that are currently available on the server. 
-func (c *Client) ListPrompts(ctx context.Context) ([]protocol.Prompt, error) { +func (c *Client) ListPrompts(ctx context.Context) ([]Prompt, error) { var ( - params = &protocol.ListPromptsParams{} - result protocol.ListPromptsResult + params = &ListPromptsParams{} + result ListPromptsResult ) if err := call(ctx, c.conn, "prompts/list", params, &result); err != nil { return nil, err @@ -172,13 +171,13 @@ func (c *Client) ListPrompts(ctx context.Context) ([]protocol.Prompt, error) { } // GetPrompt gets a prompt from the server. -func (c *Client) GetPrompt(ctx context.Context, name string, args map[string]string) (*protocol.GetPromptResult, error) { +func (c *Client) GetPrompt(ctx context.Context, name string, args map[string]string) (*GetPromptResult, error) { var ( - params = &protocol.GetPromptParams{ + params = &GetPromptParams{ Name: name, Arguments: args, } - result = &protocol.GetPromptResult{} + result = &GetPromptResult{} ) if err := call(ctx, c.conn, "prompts/get", params, result); err != nil { return nil, err @@ -187,10 +186,10 @@ func (c *Client) GetPrompt(ctx context.Context, name string, args map[string]str } // ListTools lists tools that are currently available on the server. -func (c *Client) ListTools(ctx context.Context) ([]protocol.Tool, error) { +func (c *Client) ListTools(ctx context.Context) ([]Tool, error) { var ( - params = &protocol.ListToolsParams{} - result protocol.ListToolsResult + params = &ListToolsParams{} + result ListToolsResult ) if err := call(ctx, c.conn, "tools/list", params, &result); err != nil { return nil, err @@ -199,7 +198,7 @@ func (c *Client) ListTools(ctx context.Context) ([]protocol.Tool, error) { } // CallTool calls the tool with the given name and arguments. 
-func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) (_ *protocol.CallToolResult, err error) { +func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) (_ *CallToolResult, err error) { defer func() { if err != nil { err = fmt.Errorf("calling tool %q: %w", name, err) @@ -214,21 +213,21 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) argsJSON[name] = argJSON } - params := &protocol.CallToolParams{ + params := &CallToolParams{ Name: name, Arguments: argsJSON, } - return standardCall[protocol.CallToolResult](ctx, c.conn, "tools/call", params) + return standardCall[CallToolResult](ctx, c.conn, "tools/call", params) } // ListResources lists the resources that are currently available on the server. -func (c *Client) ListResources(ctx context.Context, params *protocol.ListResourcesParams) (*protocol.ListResourcesResult, error) { - return standardCall[protocol.ListResourcesResult](ctx, c.conn, "resources/list", params) +func (c *Client) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { + return standardCall[ListResourcesResult](ctx, c.conn, "resources/list", params) } // ReadResource ask the server to read a resource and return its contents. 
-func (c *Client) ReadResource(ctx context.Context, params *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { - return standardCall[protocol.ReadResourceResult](ctx, c.conn, "resources/read", params) +func (c *Client) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { + return standardCall[ReadResourceResult](ctx, c.conn, "resources/read", params) } func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index e9f2c8fbb64..abcaff321dd 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -13,7 +13,6 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/protocol" ) const runAsServer = "_MCP_RUN_AS_SERVER" @@ -57,8 +56,8 @@ func TestCmdTransport(t *testing.T) { if err != nil { log.Fatal(err) } - want := &protocol.CallToolResult{ - Content: []protocol.Content{{Type: "text", Text: "Hi user"}}, + want := &mcp.CallToolResult{ + Content: []mcp.WireContent{{Type: "text", Text: "Hi user"}}, } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index bf318e32ee1..24cb2cfb1f1 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -5,18 +5,63 @@ package mcp import ( + "encoding/json" "fmt" - - "golang.org/x/tools/internal/mcp/protocol" ) +// The []byte fields below are marked omitzero, not omitempty: +// we want to marshal an empty byte slice. + +// WireContent is the wire format for content. +// It represents the protocol types TextContent, ImageContent, AudioContent +// and EmbeddedResource. +// The Type field distinguishes them. In the protocol, each type has a constant +// value for the field. +// At most one of Text, Data, and Resource is non-zero. 
+type WireContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitzero"` + Resource *WireResource `json:"resource,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + +// A WireResource is either a TextResourceContents or a BlobResourceContents. +// See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 +// for the inheritance structure. +// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. +// +// The URI field describes the resource location. +type WireResource struct { + URI string `json:"uri,"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text"` + Blob []byte `json:"blob,omitzero"` +} + +func (c *WireContent) UnmarshalJSON(data []byte) error { + type wireContent WireContent // for naive unmarshaling + var c2 wireContent + if err := json.Unmarshal(data, &c2); err != nil { + return err + } + switch c2.Type { + case "text", "image", "audio", "resource": + default: + return fmt.Errorf("unrecognized content type %s", c.Type) + } + *c = WireContent(c2) + return nil +} + // Content is the union of supported content types: [TextContent], // [ImageContent], [AudioContent], and [ResourceContent]. // // ToWire converts content to its jsonrpc2 wire format. type Content interface { // TODO: unexport this, and move the tests that use it to this package. - ToWire() protocol.Content + ToWire() WireContent } // TextContent is a textual content. @@ -24,8 +69,8 @@ type TextContent struct { Text string } -func (c TextContent) ToWire() protocol.Content { - return protocol.Content{Type: "text", Text: c.Text} +func (c TextContent) ToWire() WireContent { + return WireContent{Type: "text", Text: c.Text} } // ImageContent contains base64-encoded image data. 
@@ -34,8 +79,8 @@ type ImageContent struct { MIMEType string } -func (c ImageContent) ToWire() protocol.Content { - return protocol.Content{Type: "image", MIMEType: c.MIMEType, Data: c.Data} +func (c ImageContent) ToWire() WireContent { + return WireContent{Type: "image", MIMEType: c.MIMEType, Data: c.Data} } // AudioContent contains base64-encoded audio data. @@ -44,8 +89,8 @@ type AudioContent struct { MIMEType string } -func (c AudioContent) ToWire() protocol.Content { - return protocol.Content{Type: "audio", MIMEType: c.MIMEType, Data: c.Data} +func (c AudioContent) ToWire() WireContent { + return WireContent{Type: "audio", MIMEType: c.MIMEType, Data: c.Data} } // ResourceContent contains embedded resources. @@ -53,13 +98,13 @@ type ResourceContent struct { Resource EmbeddedResource } -func (r ResourceContent) ToWire() protocol.Content { +func (r ResourceContent) ToWire() WireContent { res := r.Resource.toWire() - return protocol.Content{Type: "resource", Resource: &res} + return WireContent{Type: "resource", Resource: &res} } type EmbeddedResource interface { - toWire() protocol.ResourceContents + toWire() WireResource } // The {Text,Blob}ResourceContents types match the protocol definitions, @@ -72,8 +117,8 @@ type TextResourceContents struct { Text string } -func (r TextResourceContents) toWire() protocol.ResourceContents { - return protocol.ResourceContents{ +func (r TextResourceContents) toWire() WireResource { + return WireResource{ URI: r.URI, MIMEType: r.MIMEType, Text: r.Text, @@ -88,8 +133,8 @@ type BlobResourceContents struct { Blob []byte } -func (r BlobResourceContents) toWire() protocol.ResourceContents { - return protocol.ResourceContents{ +func (r BlobResourceContents) toWire() WireResource { + return WireResource{ URI: r.URI, MIMEType: r.MIMEType, Blob: r.Blob, @@ -98,7 +143,7 @@ func (r BlobResourceContents) toWire() protocol.ResourceContents { // ContentFromWireContent converts content from the jsonrpc2 wire format to a // typed Content 
value. -func ContentFromWireContent(c protocol.Content) Content { +func ContentFromWireContent(c WireContent) Content { switch c.Type { case "text": return TextContent{Text: c.Text} diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index 1984db36be4..548989afa90 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -9,22 +9,21 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/protocol" ) func TestContent(t *testing.T) { tests := []struct { in mcp.Content - want protocol.Content + want mcp.WireContent }{ - {mcp.TextContent{Text: "hello"}, protocol.Content{Type: "text", Text: "hello"}}, + {mcp.TextContent{Text: "hello"}, mcp.WireContent{Type: "text", Text: "hello"}}, { mcp.ImageContent{Data: []byte("a1b2c3"), MIMEType: "image/png"}, - protocol.Content{Type: "image", Data: []byte("a1b2c3"), MIMEType: "image/png"}, + mcp.WireContent{Type: "image", Data: []byte("a1b2c3"), MIMEType: "image/png"}, }, { mcp.AudioContent{Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, - protocol.Content{Type: "audio", Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, + mcp.WireContent{Type: "audio", Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, }, { mcp.ResourceContent{ @@ -34,9 +33,9 @@ func TestContent(t *testing.T) { Text: "abc", }, }, - protocol.Content{ + mcp.WireContent{ Type: "resource", - Resource: &protocol.ResourceContents{ + Resource: &mcp.WireResource{ URI: "file://foo", MIMEType: "text", Text: "abc", @@ -51,9 +50,9 @@ func TestContent(t *testing.T) { Blob: []byte("a1b2c3"), }, }, - protocol.Content{ + mcp.WireContent{ Type: "resource", - Resource: &protocol.ResourceContents{ + Resource: &mcp.WireResource{ URI: "file://foo", MIMEType: "text", Blob: []byte("a1b2c3"), diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index 56a32618e7b..9cfba154d37 100644 --- a/internal/mcp/examples/hello/main.go +++ 
b/internal/mcp/examples/hello/main.go @@ -12,7 +12,6 @@ import ( "os" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/protocol" ) var httpAddr = flag.String("http", "", "if set, use SSE HTTP at this address, instead of stdin/stdout") @@ -27,10 +26,10 @@ func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]m }, nil } -func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) (*protocol.GetPromptResult, error) { - return &protocol.GetPromptResult{ +func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ Description: "Code review prompt", - Messages: []protocol.PromptMessage{ + Messages: []mcp.PromptMessage{ {Role: "user", Content: mcp.TextContent{Text: "Say hi to " + params.Name}.ToWire()}, }, }, nil diff --git a/internal/mcp/protocol/generate.go b/internal/mcp/generate.go similarity index 96% rename from internal/mcp/protocol/generate.go rename to internal/mcp/generate.go index aed9da9b433..d764aa8346a 100644 --- a/internal/mcp/protocol/generate.go +++ b/internal/mcp/generate.go @@ -67,15 +67,15 @@ var declarations = config{ Fields: config{"Params": {Name: "GetPromptParams"}}, }, "GetPromptResult": {}, - "Implementation": {}, + "Implementation": {Name: "implementation"}, "InitializeRequest": { Name: "-", - Fields: config{"Params": {Name: "InitializeParams"}}, + Fields: config{"Params": {Name: "initializeParams"}}, }, - "InitializeResult": {}, + "InitializeResult": {Name: "initializeResult"}, "InitializedNotification": { Name: "-", - Fields: config{"Params": {Name: "InitializedParams"}}, + Fields: config{"Params": {Name: "initializedParams"}}, }, "ListPromptsRequest": { Name: "-", @@ -107,17 +107,18 @@ var declarations = config{ Fields: config{"Params": {Name: "ReadResourceParams"}}, }, "ReadResourceResult": { - Fields: config{"Contents": {Substitute: "*ResourceContents"}}, + Fields: config{"Contents": {Substitute: 
"*WireResource"}}, }, "Resource": {}, "Role": {}, "Root": {}, "ServerCapabilities": { + Name: "serverCapabilities", Fields: config{ - "Prompts": {Name: "PromptCapabilities"}, - "Resources": {Name: "ResourceCapabilities"}, - "Tools": {Name: "ToolCapabilities"}, + "Prompts": {Name: "promptCapabilities"}, + "Resources": {Name: "resourceCapabilities"}, + "Tools": {Name: "toolCapabilities"}, }, }, "Tool": { @@ -161,7 +162,7 @@ func main() { // Code generated by generate.go. DO NOT EDIT. -package protocol +package mcp import ( "encoding/json" @@ -287,7 +288,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma if slices.ContainsFunc(def.AnyOf, func(s *jsonschema.Schema) bool { return s.Ref == "#/definitions/TextContent" }) { - fmt.Fprintf(w, "Content") + fmt.Fprintf(w, "WireContent") } else { // E.g. union types. fmt.Fprintf(w, "json.RawMessage") diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index a20b11def3f..201f42092f2 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:generate go run generate.go + // The mcp package provides an SDK for writing model context protocol clients // and servers. 
// diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index a4defd6f4d3..65ef7222035 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -18,7 +18,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" "golang.org/x/tools/internal/mcp/jsonschema" - "golang.org/x/tools/internal/mcp/protocol" ) type hiParams struct { @@ -50,15 +49,15 @@ func TestEndToEnd(t *testing.T) { ) s.AddPrompts( - NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerConnection, params struct{ Code string }) (*protocol.GetPromptResult, error) { - return &protocol.GetPromptResult{ + NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerConnection, params struct{ Code string }) (*GetPromptResult, error) { + return &GetPromptResult{ Description: "Code review prompt", - Messages: []protocol.PromptMessage{ + Messages: []PromptMessage{ {Role: "user", Content: TextContent{Text: "Please review the following code: " + params.Code}.ToWire()}, }, }, nil }), - NewPrompt("fail", "", func(_ context.Context, _ *ServerConnection, params struct{}) (*protocol.GetPromptResult, error) { + NewPrompt("fail", "", func(_ context.Context, _ *ServerConnection, params struct{}) (*GetPromptResult, error) { return nil, failure }), ) @@ -83,7 +82,7 @@ func TestEndToEnd(t *testing.T) { }() c := NewClient("testClient", "v1.0.0", ct, nil) - c.AddRoots(protocol.Root{URI: "file:///root"}) + c.AddRoots(Root{URI: "file:///root"}) // Connect the client. 
if err := c.Start(ctx); err != nil { @@ -98,11 +97,11 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Errorf("prompts/list failed: %v", err) } - wantPrompts := []protocol.Prompt{ + wantPrompts := []Prompt{ { Name: "code_review", Description: "do a code review", - Arguments: []protocol.PromptArgument{{Name: "Code", Required: true}}, + Arguments: []PromptArgument{{Name: "Code", Required: true}}, }, {Name: "fail"}, } @@ -114,9 +113,9 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Fatal(err) } - wantReview := &protocol.GetPromptResult{ + wantReview := &GetPromptResult{ Description: "Code review prompt", - Messages: []protocol.PromptMessage{{ + Messages: []PromptMessage{{ Content: TextContent{Text: "Please review the following code: 1+1"}.ToWire(), Role: "user", }}, @@ -135,7 +134,7 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Errorf("tools/list failed: %v", err) } - wantTools := []protocol.Tool{ + wantTools := []Tool{ { Name: "fail", Description: "just fail", @@ -165,8 +164,8 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Fatal(err) } - wantHi := &protocol.CallToolResult{ - Content: []protocol.Content{{Type: "text", Text: "hi user"}}, + wantHi := &CallToolResult{ + Content: []WireContent{{Type: "text", Text: "hi user"}}, } if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) @@ -178,9 +177,9 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Fatal(err) } - wantFail := &protocol.CallToolResult{ + wantFail := &CallToolResult{ IsError: true, - Content: []protocol.Content{{Type: "text", Text: failure.Error()}}, + Content: []WireContent{{Type: "text", Text: failure.Error()}}, } if diff := cmp.Diff(wantFail, gotFail); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) @@ -188,21 +187,21 @@ func TestEndToEnd(t *testing.T) { }) t.Run("resources", func(t *testing.T) { - resource1 := protocol.Resource{ + resource1 := Resource{ Name: "public", MIMEType: 
"text/plain", URI: "file:///file1.txt", } - resource2 := protocol.Resource{ + resource2 := Resource{ Name: "public", // names are not unique IDs MIMEType: "text/plain", URI: "file:///nonexistent.txt", } - readHandler := func(_ context.Context, r protocol.Resource, _ *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { + readHandler := func(_ context.Context, r Resource, _ *ReadResourceParams) (*ReadResourceResult, error) { if r.URI == "file:///file1.txt" { - return &protocol.ReadResourceResult{ - Contents: &protocol.ResourceContents{ + return &ReadResourceResult{ + Contents: &WireResource{ Text: "file contents", }, }, nil @@ -217,7 +216,7 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff([]protocol.Resource{resource1, resource2}, lrres.Resources); diff != "" { + if diff := cmp.Diff([]Resource{resource1, resource2}, lrres.Resources); diff != "" { t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) } @@ -229,7 +228,7 @@ func TestEndToEnd(t *testing.T) { {"file:///nonexistent.txt", ""}, // TODO(jba): add resource template cases when we implement them } { - rres, err := c.ReadResource(ctx, &protocol.ReadResourceParams{URI: tt.uri}) + rres, err := c.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) if err != nil { var werr *jsonrpc2.WireError if errors.As(err, &werr) && werr.Code == codeResourceNotFound { @@ -256,7 +255,7 @@ func TestEndToEnd(t *testing.T) { break } - rootRes, err := sc.ListRoots(ctx, &protocol.ListRootsParams{}) + rootRes, err := sc.ListRoots(ctx, &ListRootsParams{}) if err != nil { t.Fatal(err) } @@ -283,7 +282,7 @@ func TestEndToEnd(t *testing.T) { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. 
-func basicConnection(t *testing.T, tools ...*Tool) (*ServerConnection, *Client) { +func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerConnection, *Client) { t.Helper() ctx := context.Background() diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 878a54eed7c..3eb3bb53668 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -13,15 +13,14 @@ import ( "golang.org/x/tools/internal/mcp/internal/util" "golang.org/x/tools/internal/mcp/jsonschema" - "golang.org/x/tools/internal/mcp/protocol" ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerConnection, map[string]string) (*protocol.GetPromptResult, error) +type PromptHandler func(context.Context, *ServerConnection, map[string]string) (*GetPromptResult, error) // A Prompt is a prompt definition bound to a prompt handler. -type Prompt struct { - Definition protocol.Prompt +type ServerPrompt struct { + Definition Prompt Handler PromptHandler } @@ -33,7 +32,7 @@ type Prompt struct { // of type string or *string. The argument names for the resulting prompt // definition correspond to the JSON names of the request fields, and any // fields that are not marked "omitempty" are considered required. 
-func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) (*protocol.GetPromptResult, error), opts ...PromptOption) *Prompt { +func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) @@ -41,8 +40,8 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if schema.Type != "object" || !reflect.DeepEqual(schema.AdditionalProperties, &jsonschema.Schema{Not: &jsonschema.Schema{}}) { panic(fmt.Sprintf("handler request type must be a struct")) } - prompt := &Prompt{ - Definition: protocol.Prompt{ + prompt := &ServerPrompt{ + Definition: Prompt{ Name: name, Description: description, }, @@ -55,13 +54,13 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if prop.Type != "string" { panic(fmt.Sprintf("handler type must consist only of string fields")) } - prompt.Definition.Arguments = append(prompt.Definition.Arguments, protocol.PromptArgument{ + prompt.Definition.Arguments = append(prompt.Definition.Arguments, PromptArgument{ Name: name, Description: prop.Description, Required: required[name], }) } - prompt.Handler = func(ctx context.Context, cc *ServerConnection, args map[string]string) (*protocol.GetPromptResult, error) { + prompt.Handler = func(ctx context.Context, cc *ServerConnection, args map[string]string) (*GetPromptResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. rawArgs, err := json.Marshal(args) @@ -82,12 +81,12 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, // A PromptOption configures the behavior of a Prompt. 
type PromptOption interface { - set(*Prompt) + set(*ServerPrompt) } -type promptSetter func(*Prompt) +type promptSetter func(*ServerPrompt) -func (s promptSetter) set(p *Prompt) { s(p) } +func (s promptSetter) set(p *ServerPrompt) { s(p) } // Argument configures the 'schema' of a prompt argument. // If the argument does not exist, it is added. @@ -95,14 +94,14 @@ func (s promptSetter) set(p *Prompt) { s(p) } // Since prompt arguments are not a full JSON schema, Argument only accepts // Required and Description, and panics when encountering any other option. func Argument(name string, opts ...SchemaOption) PromptOption { - return promptSetter(func(p *Prompt) { - i := slices.IndexFunc(p.Definition.Arguments, func(arg protocol.PromptArgument) bool { + return promptSetter(func(p *ServerPrompt) { + i := slices.IndexFunc(p.Definition.Arguments, func(arg PromptArgument) bool { return arg.Name == name }) - var arg protocol.PromptArgument + var arg PromptArgument if i < 0 { i = len(p.Definition.Arguments) - arg = protocol.PromptArgument{Name: name} + arg = PromptArgument{Name: name} p.Definition.Arguments = append(p.Definition.Arguments, arg) } else { arg = p.Definition.Arguments[i] diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 22afe812305..29a1c5ac5a0 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -10,18 +10,17 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" - "golang.org/x/tools/internal/mcp/protocol" ) // testPromptHandler is used for type inference in TestNewPrompt. 
-func testPromptHandler[T any](context.Context, *mcp.ServerConnection, T) (*protocol.GetPromptResult, error) { +func testPromptHandler[T any](context.Context, *mcp.ServerConnection, T) (*mcp.GetPromptResult, error) { panic("not implemented") } func TestNewPrompt(t *testing.T) { tests := []struct { - prompt *mcp.Prompt - want []protocol.PromptArgument + prompt *mcp.ServerPrompt + want []mcp.PromptArgument }{ { mcp.NewPrompt("empty", "", testPromptHandler[struct{}]), @@ -29,7 +28,7 @@ func TestNewPrompt(t *testing.T) { }, { mcp.NewPrompt("add_arg", "", testPromptHandler[struct{}], mcp.Argument("x")), - []protocol.PromptArgument{{Name: "x"}}, + []mcp.PromptArgument{{Name: "x"}}, }, { mcp.NewPrompt("combo", "", testPromptHandler[struct { @@ -39,7 +38,7 @@ func TestNewPrompt(t *testing.T) { }], mcp.Argument("name", mcp.Description("the person's name")), mcp.Argument("State", mcp.Required(false))), - []protocol.PromptArgument{ + []mcp.PromptArgument{ {Name: "State"}, {Name: "country"}, {Name: "name", Required: true, Description: "the person's name"}, diff --git a/internal/mcp/protocol/protocol.go b/internal/mcp/protocol.go similarity index 94% rename from internal/mcp/protocol/protocol.go rename to internal/mcp/protocol.go index 35b4088de6a..9c7bde61dee 100644 --- a/internal/mcp/protocol/protocol.go +++ b/internal/mcp/protocol.go @@ -4,7 +4,7 @@ // Code generated by generate.go. DO NOT EDIT. -package protocol +package mcp import ( "encoding/json" @@ -47,7 +47,7 @@ type CallToolResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content []Content `json:"content"` + Content []WireContent `json:"content"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). 
@@ -99,46 +99,6 @@ type GetPromptResult struct { Messages []PromptMessage `json:"messages"` } -// Describes the name and version of an MCP implementation. -type Implementation struct { - Name string `json:"name"` - Version string `json:"version"` -} - -type InitializeParams struct { - Capabilities ClientCapabilities `json:"capabilities"` - ClientInfo Implementation `json:"clientInfo"` - // The latest version of the Model Context Protocol that the client supports. - // The client MAY decide to support older versions as well. - ProtocolVersion string `json:"protocolVersion"` -} - -// After receiving an initialize request from the client, the server sends this -// response. -type InitializeResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Capabilities ServerCapabilities `json:"capabilities"` - // Instructions describing how to use the server and its features. - // - // This can be used by clients to improve the LLM's understanding of available - // tools, resources, etc. It can be thought of like a "hint" to the model. For - // example, this information MAY be added to the system prompt. - Instructions string `json:"instructions,omitempty"` - // The version of the Model Context Protocol that the server wants to use. This - // may not match the version that the client requested. If the client cannot - // support this version, it MUST disconnect. - ProtocolVersion string `json:"protocolVersion"` - ServerInfo Implementation `json:"serverInfo"` -} - -type InitializedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` -} - type ListPromptsParams struct { // An opaque token representing the current pagination position. 
If provided, // the server should return results starting after this cursor. @@ -230,19 +190,13 @@ type PromptArgument struct { Required bool `json:"required,omitempty"` } -// Present if the server offers any prompt templates. -type PromptCapabilities struct { - // Whether this server supports notifications for changes to the prompt list. - ListChanged bool `json:"listChanged,omitempty"` -} - // Describes a message returned as part of a prompt. // // This is similar to `SamplingMessage`, but also supports the embedding of // resources from the MCP server. type PromptMessage struct { - Content Content `json:"content"` - Role Role `json:"role"` + Content WireContent `json:"content"` + Role Role `json:"role"` } type ReadResourceParams struct { @@ -256,7 +210,7 @@ type ReadResourceResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Contents *ResourceContents `json:"contents"` + Contents *WireResource `json:"contents"` } // A known resource that the server is capable of reading. @@ -284,14 +238,6 @@ type Resource struct { URI string `json:"uri"` } -// Present if the server offers any resources to read. -type ResourceCapabilities struct { - // Whether this server supports notifications for changes to the resource list. - ListChanged bool `json:"listChanged,omitempty"` - // Whether this server supports subscribing to resource updates. - Subscribe bool `json:"subscribe,omitempty"` -} - // The sender or recipient of messages and data in a conversation. type Role string @@ -307,27 +253,6 @@ type Root struct { URI string `json:"uri"` } -// Capabilities that a server may support. Known capabilities are defined here, -// in this schema, but this is not a closed set: any server can define its own, -// additional capabilities. 
-type ServerCapabilities struct { - // Present if the server supports argument autocompletion suggestions. - Completions struct { - } `json:"completions,omitempty"` - // Experimental, non-standard capabilities that the server supports. - Experimental map[string]struct { - } `json:"experimental,omitempty"` - // Present if the server supports sending log messages to the client. - Logging struct { - } `json:"logging,omitempty"` - // Present if the server offers any prompt templates. - Prompts *PromptCapabilities `json:"prompts,omitempty"` - // Present if the server offers any resources to read. - Resources *ResourceCapabilities `json:"resources,omitempty"` - // Present if the server offers any tools to call. - Tools *ToolCapabilities `json:"tools,omitempty"` -} - // Definition for a tool the client can call. type Tool struct { // Optional additional tool information. @@ -380,8 +305,83 @@ type ToolAnnotations struct { Title string `json:"title,omitempty"` } +// Describes the name and version of an MCP implementation. +type implementation struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type initializeParams struct { + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo implementation `json:"clientInfo"` + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` +} + +// After receiving an initialize request from the client, the server sends this +// response. +type initializeResult struct { + // This result property is reserved by the protocol to allow clients and servers + // to attach additional metadata to their responses. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + Capabilities serverCapabilities `json:"capabilities"` + // Instructions describing how to use the server and its features. 
+ // + // This can be used by clients to improve the LLM's understanding of available + // tools, resources, etc. It can be thought of like a "hint" to the model. For + // example, this information MAY be added to the system prompt. + Instructions string `json:"instructions,omitempty"` + // The version of the Model Context Protocol that the server wants to use. This + // may not match the version that the client requested. If the client cannot + // support this version, it MUST disconnect. + ProtocolVersion string `json:"protocolVersion"` + ServerInfo implementation `json:"serverInfo"` +} + +type initializedParams struct { + // This parameter name is reserved by MCP to allow clients and servers to attach + // additional metadata to their notifications. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` +} + +// Present if the server offers any prompt templates. +type promptCapabilities struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// Present if the server offers any resources to read. +type resourceCapabilities struct { + // Whether this server supports notifications for changes to the resource list. + ListChanged bool `json:"listChanged,omitempty"` + // Whether this server supports subscribing to resource updates. + Subscribe bool `json:"subscribe,omitempty"` +} + +// Capabilities that a server may support. Known capabilities are defined here, +// in this schema, but this is not a closed set: any server can define its own, +// additional capabilities. +type serverCapabilities struct { + // Present if the server supports argument autocompletion suggestions. + Completions struct { + } `json:"completions,omitempty"` + // Experimental, non-standard capabilities that the server supports. + Experimental map[string]struct { + } `json:"experimental,omitempty"` + // Present if the server supports sending log messages to the client. 
+ Logging struct { + } `json:"logging,omitempty"` + // Present if the server offers any prompt templates. + Prompts *promptCapabilities `json:"prompts,omitempty"` + // Present if the server offers any resources to read. + Resources *resourceCapabilities `json:"resources,omitempty"` + // Present if the server offers any tools to call. + Tools *toolCapabilities `json:"tools,omitempty"` +} + // Present if the server offers any tools to call. -type ToolCapabilities struct { +type toolCapabilities struct { // Whether this server supports notifications for changes to the tool list. ListChanged bool `json:"listChanged,omitempty"` } diff --git a/internal/mcp/protocol/content.go b/internal/mcp/protocol/content.go deleted file mode 100644 index 76f017da6cd..00000000000 --- a/internal/mcp/protocol/content.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package protocol - -import ( - "encoding/json" - "fmt" -) - -// The []byte fields below are marked omitzero, not omitempty: -// we want to marshal an empty byte slice. - -// Content is the wire format for content. -// It represents the protocol types TextContent, ImageContent, AudioContent -// and EmbeddedResource. -// The Type field distinguishes them. In the protocol, each type has a constant -// value for the field. -// At most one of Text, Data, and Resource is non-zero. -type Content struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Data []byte `json:"data,omitzero"` - Resource *ResourceContents `json:"resource,omitempty"` - Annotations *Annotations `json:"annotations,omitempty"` -} - -// A ResourceContents is either a TextResourceContents or a BlobResourceContents. 
-// See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 -// for the inheritance structure. -// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. -// -// The URI field describes the resource location. -type ResourceContents struct { - URI string `json:"uri,"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text"` - Blob []byte `json:"blob,omitzero"` -} - -func (c *Content) UnmarshalJSON(data []byte) error { - type wireContent Content // for naive unmarshaling - var c2 wireContent - if err := json.Unmarshal(data, &c2); err != nil { - return err - } - switch c2.Type { - case "text", "image", "audio", "resource": - default: - return fmt.Errorf("unrecognized content type %s", c.Type) - } - *c = Content(c2) - return nil -} diff --git a/internal/mcp/protocol/doc.go b/internal/mcp/protocol/doc.go deleted file mode 100644 index ec86936a35d..00000000000 --- a/internal/mcp/protocol/doc.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:generate go run generate.go - -// The protocol package contains types that define the MCP protocol. -// -// It is auto-generated from the MCP spec. Run go generate to update it. -// The generated set of types is intended to be minimal, in the sense that we -// only generate types that are actually used by the SDK. See generate.go for -// instructions on how to generate more (or different) types. 
-package protocol diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 95fd2b5f561..f6f5cd1d43a 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -9,13 +9,11 @@ import ( "encoding/json" "fmt" "iter" - "log" "net/url" "slices" "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/protocol" ) // A Server is an instance of an MCP server. @@ -29,8 +27,8 @@ type Server struct { opts ServerOptions mu sync.Mutex - prompts *featureSet[*Prompt] - tools *featureSet[*Tool] + prompts *featureSet[*ServerPrompt] + tools *featureSet[*ServerTool] resources *featureSet[*ServerResource] conns []*ServerConnection } @@ -55,15 +53,15 @@ func NewServer(name, version string, opts *ServerOptions) *Server { name: name, version: version, opts: *opts, - prompts: newFeatureSet(func(p *Prompt) string { return p.Definition.Name }), - tools: newFeatureSet(func(t *Tool) string { return t.Definition.Name }), + prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Definition.Name }), + tools: newFeatureSet(func(t *ServerTool) string { return t.Definition.Name }), resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), } } // AddPrompts adds the given prompts to the server, // replacing any with the same names. -func (s *Server) AddPrompts(prompts ...*Prompt) { +func (s *Server) AddPrompts(prompts ...*ServerPrompt) { s.mu.Lock() defer s.mu.Unlock() s.prompts.add(prompts...) @@ -84,7 +82,7 @@ func (s *Server) RemovePrompts(names ...string) { // AddTools adds the given tools to the server, // replacing any with the same names. -func (s *Server) AddTools(tools ...*Tool) { +func (s *Server) AddTools(tools ...*ServerTool) { s.mu.Lock() defer s.mu.Unlock() s.tools.add(tools...) @@ -123,11 +121,11 @@ const codeResourceNotFound = -31002 // A ReadResourceHandler is a function that reads a resource. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. 
-type ReadResourceHandler func(context.Context, protocol.Resource, *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) +type ReadResourceHandler func(context.Context, Resource, *ReadResourceParams) (*ReadResourceResult, error) // A ServerResource associates a Resource with its handler. type ServerResource struct { - Resource protocol.Resource + Resource Resource Handler ReadResourceHandler } @@ -168,17 +166,17 @@ func (s *Server) Clients() iter.Seq[*ServerConnection] { return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *protocol.ListPromptsParams) (*protocol.ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *ListPromptsParams) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() - res := new(protocol.ListPromptsResult) + res := new(ListPromptsResult) for p := range s.prompts.all() { res.Prompts = append(res.Prompts, p.Definition) } return res, nil } -func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *protocol.GetPromptParams) (*protocol.GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *GetPromptParams) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(params.Name) s.mu.Unlock() @@ -189,17 +187,17 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *pr return prompt.Handler(ctx, cc, params.Arguments) } -func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *protocol.ListToolsParams) (*protocol.ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *ListToolsParams) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() - res := new(protocol.ListToolsResult) + res := new(ListToolsResult) for t := range s.tools.all() { res.Tools = append(res.Tools, t.Definition) } return res, nil } -func (s *Server) callTool(ctx context.Context, cc 
*ServerConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *CallToolParams) (*CallToolResult, error) { s.mu.Lock() tool, ok := s.tools.get(params.Name) s.mu.Unlock() @@ -209,19 +207,17 @@ func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *pro return tool.Handler(ctx, cc, params.Arguments) } -func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *protocol.ListResourcesParams) (*protocol.ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *ListResourcesParams) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() - res := new(protocol.ListResourcesResult) + res := new(ListResourcesResult) for r := range s.resources.all() { res.Resources = append(res.Resources, r.Resource) } return res, nil } -func (s *Server) readResource(ctx context.Context, _ *ServerConnection, params *protocol.ReadResourceParams) (*protocol.ReadResourceResult, error) { - log.Printf("readResource") - defer log.Printf("done") +func (s *Server) readResource(ctx context.Context, _ *ServerConnection, params *ReadResourceParams) (*ReadResourceResult, error) { uri := params.URI // Look up the resource URI in the list we have. // This is a security check as well as an information lookup. 
@@ -302,7 +298,7 @@ type ServerConnection struct { conn *jsonrpc2.Connection mu sync.Mutex - initializeParams *protocol.InitializeParams + initializeParams *initializeParams initialized bool } @@ -311,8 +307,8 @@ func (cc *ServerConnection) Ping(ctx context.Context) error { return call(ctx, cc.conn, "ping", nil, nil) } -func (cc *ServerConnection) ListRoots(ctx context.Context, params *protocol.ListRootsParams) (*protocol.ListRootsResult, error) { - return standardCall[protocol.ListRootsResult](ctx, cc.conn, "roots/list", params) +func (cc *ServerConnection) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + return standardCall[ListRootsResult](ctx, cc.conn, "roots/list", params) } func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { @@ -366,7 +362,7 @@ func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( return nil, jsonrpc2.ErrNotHandled } -func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, params *protocol.InitializeParams) (*protocol.InitializeResult, error) { +func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, params *initializeParams) (*initializeResult, error) { cc.mu.Lock() cc.initializeParams = params cc.mu.Unlock() @@ -382,19 +378,19 @@ func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, cc.mu.Unlock() }() - return &protocol.InitializeResult{ + return &initializeResult{ // TODO(rfindley): support multiple protocol versions. 
ProtocolVersion: "2024-11-05", - Capabilities: protocol.ServerCapabilities{ - Prompts: &protocol.PromptCapabilities{ + Capabilities: serverCapabilities{ + Prompts: &promptCapabilities{ ListChanged: false, // not yet supported }, - Tools: &protocol.ToolCapabilities{ + Tools: &toolCapabilities{ ListChanged: false, // not yet supported }, }, Instructions: cc.server.opts.Instructions, - ServerInfo: protocol.Implementation{ + ServerInfo: implementation{ Name: cc.server.name, Version: cc.server.version, }, diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index f901964a51f..de7580ea847 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -12,7 +12,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "golang.org/x/tools/internal/mcp/protocol" ) func TestSSEServer(t *testing.T) { @@ -48,8 +47,8 @@ func TestSSEServer(t *testing.T) { if err != nil { t.Fatal(err) } - wantHi := &protocol.CallToolResult{ - Content: []protocol.Content{{Type: "text", Text: "hi user"}}, + wantHi := &CallToolResult{ + Content: []WireContent{{Type: "text", Text: "hi user"}}, } if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 4c0f0eafe00..c2bab6eccd6 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -11,15 +11,14 @@ import ( "golang.org/x/tools/internal/mcp/internal/util" "golang.org/x/tools/internal/mcp/jsonschema" - "golang.org/x/tools/internal/mcp/protocol" ) // A ToolHandler handles a call to tools/call. -type ToolHandler func(context.Context, *ServerConnection, map[string]json.RawMessage) (*protocol.CallToolResult, error) +type ToolHandler func(context.Context, *ServerConnection, map[string]json.RawMessage) (*CallToolResult, error) // A Tool is a tool definition that is bound to a tool handler. 
-type Tool struct { - Definition protocol.Tool +type ServerTool struct { + Definition Tool Handler ToolHandler } @@ -36,12 +35,12 @@ type Tool struct { // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]Content, error), opts ...ToolOption) *Tool { +func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]Content, error), opts ...ToolOption) *ServerTool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) } - wrapped := func(ctx context.Context, cc *ServerConnection, args map[string]json.RawMessage) (*protocol.CallToolResult, error) { + wrapped := func(ctx context.Context, cc *ServerConnection, args map[string]json.RawMessage) (*CallToolResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. rawArgs, err := json.Marshal(args) @@ -56,18 +55,18 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * // TODO: investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. if err != nil { - return &protocol.CallToolResult{ - Content: []protocol.Content{TextContent{Text: err.Error()}.ToWire()}, + return &CallToolResult{ + Content: []WireContent{TextContent{Text: err.Error()}.ToWire()}, IsError: true, }, nil } - res := &protocol.CallToolResult{ + res := &CallToolResult{ Content: util.Apply(content, Content.ToWire), } return res, nil } - t := &Tool{ - Definition: protocol.Tool{ + t := &ServerTool{ + Definition: Tool{ Name: name, Description: description, InputSchema: schema, @@ -90,17 +89,17 @@ func unmarshalSchema(data json.RawMessage, _ *jsonschema.Schema, v any) error { // A ToolOption configures the behavior of a Tool. 
type ToolOption interface { - set(*Tool) + set(*ServerTool) } -type toolSetter func(*Tool) +type toolSetter func(*ServerTool) -func (s toolSetter) set(t *Tool) { s(t) } +func (s toolSetter) set(t *ServerTool) { s(t) } // Input applies the provided [SchemaOption] configuration to the tool's input // schema. func Input(opts ...SchemaOption) ToolOption { - return toolSetter(func(t *Tool) { + return toolSetter(func(t *ServerTool) { for _, opt := range opts { opt.set(t.Definition.InputSchema) } diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index 45bc82048e1..e6bf4c0eafb 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -21,7 +21,7 @@ func testToolHandler[T any](context.Context, *mcp.ServerConnection, T) ([]mcp.Co func TestNewTool(t *testing.T) { tests := []struct { - tool *mcp.Tool + tool *mcp.ServerTool want *jsonschema.Schema }{ { diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index faa2a806586..7c902848d60 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -15,7 +15,6 @@ import ( "sync" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/mcp/protocol" "golang.org/x/tools/internal/xcontext" ) @@ -131,7 +130,7 @@ type canceller struct { // Preempt implements jsonrpc2.Preempter. func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result any, err error) { if req.Method == "notifications/cancelled" { - var params protocol.CancelledParams + var params CancelledParams if err := json.Unmarshal(req.Params, &params); err != nil { return nil, err } @@ -156,7 +155,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) case ctx.Err() != nil: // Notify the peer of cancellation.
- err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &protocol.CancelledParams{ + err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &CancelledParams{ Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), }) From 34082a6be05c5aabe40b13abe113cc472e9d5c94 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 13 May 2025 21:51:26 +0000 Subject: [PATCH 056/196] internal/mcp: simplify content Now that the protocol package is merged into the mcp package, clean up the higher level 'Content' union types, in favor of types constructors. Change-Id: I37bedb0ff0a296eef799ba65c37ca7f94ee1dbbe Reviewed-on: https://go-review.googlesource.com/c/tools/+/672415 Reviewed-by: Jonathan Amsterdam TryBot-Bypass: Jonathan Amsterdam --- gopls/internal/mcp/mcp.go | 2 +- internal/mcp/cmd_test.go | 2 +- internal/mcp/content.go | 160 +++++++--------------------- internal/mcp/content_test.go | 67 +++++------- internal/mcp/design/design.md | 30 +++--- internal/mcp/examples/hello/main.go | 4 +- internal/mcp/examples/sse/main.go | 2 +- internal/mcp/generate.go | 4 +- internal/mcp/internal/util/util.go | 9 -- internal/mcp/mcp_test.go | 12 +-- internal/mcp/protocol.go | 8 +- internal/mcp/server_example_test.go | 2 +- internal/mcp/sse_example_test.go | 2 +- internal/mcp/sse_test.go | 2 +- internal/mcp/tool.go | 5 +- internal/mcp/util.go | 5 - 16 files changed, 108 insertions(+), 208 deletions(-) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index ac09ffc300c..b21a700d598 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -143,7 +143,7 @@ type HelloParams struct { func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { return func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { return []mcp.Content{ - mcp.TextContent{Text: "Hi " + request.Name + ", this is lsp session " + session.ID()}, + 
mcp.NewTextContent("Hi " + request.Name + ", this is lsp session " + session.ID()), }, nil } } diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index abcaff321dd..f4ce863785a 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -57,7 +57,7 @@ func TestCmdTransport(t *testing.T) { log.Fatal(err) } want := &mcp.CallToolResult{ - Content: []mcp.WireContent{{Type: "text", Text: "Hi user"}}, + Content: []mcp.Content{{Type: "text", Text: "Hi user"}}, } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 24cb2cfb1f1..8e4f282b74b 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -11,37 +11,25 @@ import ( // The []byte fields below are marked omitzero, not omitempty: // we want to marshal an empty byte slice. +// TODO(jba): figure out how to fix this for 1.23. -// WireContent is the wire format for content. +// Content is the wire format for content. // It represents the protocol types TextContent, ImageContent, AudioContent // and EmbeddedResource. // The Type field distinguishes them. In the protocol, each type has a constant // value for the field. // At most one of Text, Data, and Resource is non-zero. -type WireContent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Data []byte `json:"data,omitzero"` - Resource *WireResource `json:"resource,omitempty"` - Annotations *Annotations `json:"annotations,omitempty"` +type Content struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitzero"` + Resource *ResourceContents `json:"resource,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` } -// A WireResource is either a TextResourceContents or a BlobResourceContents. 
-// See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 -// for the inheritance structure. -// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. -// -// The URI field describes the resource location. -type WireResource struct { - URI string `json:"uri,"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text"` - Blob []byte `json:"blob,omitzero"` -} - -func (c *WireContent) UnmarshalJSON(data []byte) error { - type wireContent WireContent // for naive unmarshaling +func (c *Content) UnmarshalJSON(data []byte) error { + type wireContent Content // for naive unmarshaling var c2 wireContent if err := json.Unmarshal(data, &c2); err != nil { return err @@ -51,77 +39,44 @@ func (c *WireContent) UnmarshalJSON(data []byte) error { default: return fmt.Errorf("unrecognized content type %s", c.Type) } - *c = WireContent(c2) + *c = Content(c2) return nil } -// Content is the union of supported content types: [TextContent], -// [ImageContent], [AudioContent], and [ResourceContent]. -// -// ToWire converts content to its jsonrpc2 wire format. -type Content interface { - // TODO: unexport this, and move the tests that use it to this package. - ToWire() WireContent -} - -// TextContent is a textual content. -type TextContent struct { - Text string -} - -func (c TextContent) ToWire() WireContent { - return WireContent{Type: "text", Text: c.Text} -} - -// ImageContent contains base64-encoded image data. -type ImageContent struct { - Data []byte // base64-encoded - MIMEType string -} - -func (c ImageContent) ToWire() WireContent { - return WireContent{Type: "image", MIMEType: c.MIMEType, Data: c.Data} +func NewTextContent(text string) Content { + return Content{Type: "text", Text: text} } -// AudioContent contains base64-encoded audio data. 
-type AudioContent struct { - Data []byte - MIMEType string -} - -func (c AudioContent) ToWire() WireContent { - return WireContent{Type: "audio", MIMEType: c.MIMEType, Data: c.Data} -} - -// ResourceContent contains embedded resources. -type ResourceContent struct { - Resource EmbeddedResource +func NewImageContent(data []byte, mimeType string) Content { + return Content{Type: "image", Data: data, MIMEType: mimeType} } -func (r ResourceContent) ToWire() WireContent { - res := r.Resource.toWire() - return WireContent{Type: "resource", Resource: &res} +func NewAudioContent(data []byte, mimeType string) Content { + return Content{Type: "audio", Data: data, MIMEType: mimeType} } -type EmbeddedResource interface { - toWire() WireResource +func NewResourceContent(resource ResourceContents) Content { + return Content{Type: "resource", Resource: &resource} } -// The {Text,Blob}ResourceContents types match the protocol definitions, -// but we represent both as a single type on the wire. - -// A TextResourceContents is the contents of a text resource. -type TextResourceContents struct { - URI string - MIMEType string - Text string +// A ResourceContents is either a TextResourceContents or a BlobResourceContents. +// See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 +// for the inheritance structure. +// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. +// +// The URI field describes the resource location. 
+type ResourceContents struct { + URI string `json:"uri,"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitzero"` } -func (r TextResourceContents) toWire() WireResource { - return WireResource{ - URI: r.URI, - MIMEType: r.MIMEType, - Text: r.Text, +func NewTextResourceContents(uri, mimeType, text string) ResourceContents { + return ResourceContents{ + URI: uri, + MIMEType: mimeType, + Text: text, // Blob is nil, indicating this is a TextResourceContents. } } @@ -133,43 +88,10 @@ type BlobResourceContents struct { Blob []byte } -func (r BlobResourceContents) toWire() WireResource { - return WireResource{ - URI: r.URI, - MIMEType: r.MIMEType, - Blob: r.Blob, - } -} - -// ContentFromWireContent converts content from the jsonrpc2 wire format to a -// typed Content value. -func ContentFromWireContent(c WireContent) Content { - switch c.Type { - case "text": - return TextContent{Text: c.Text} - case "image": - return ImageContent{Data: c.Data, MIMEType: c.MIMEType} - case "audio": - return AudioContent{Data: c.Data, MIMEType: c.MIMEType} - case "resource": - r := ResourceContent{} - if c.Resource != nil { - if c.Resource.Blob != nil { - r.Resource = BlobResourceContents{ - URI: c.Resource.URI, - MIMEType: c.Resource.MIMEType, - Blob: c.Resource.Blob, - } - } else { - r.Resource = TextResourceContents{ - URI: c.Resource.URI, - MIMEType: c.Resource.MIMEType, - Text: c.Resource.Text, - } - } - } - return r - default: - panic(fmt.Sprintf("unrecognized wire content type %q", c.Type)) +func NewBlobResourceContents(uri, mimeType string, blob []byte) ResourceContents { + return ResourceContents{ + URI: uri, + MIMEType: mimeType, + Blob: blob, } } diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index 548989afa90..18f34a41a0d 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -5,6 +5,7 @@ package mcp_test import ( + "encoding/json" "testing" 
"github.com/google/go-cmp/cmp" @@ -14,57 +15,45 @@ import ( func TestContent(t *testing.T) { tests := []struct { in mcp.Content - want mcp.WireContent + want string // json serialization }{ - {mcp.TextContent{Text: "hello"}, mcp.WireContent{Type: "text", Text: "hello"}}, + {mcp.NewTextContent("hello"), `{"type":"text","text":"hello"}`}, { - mcp.ImageContent{Data: []byte("a1b2c3"), MIMEType: "image/png"}, - mcp.WireContent{Type: "image", Data: []byte("a1b2c3"), MIMEType: "image/png"}, + mcp.NewImageContent([]byte("a1b2c3"), "image/png"), + `{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}`, }, { - mcp.AudioContent{Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, - mcp.WireContent{Type: "audio", Data: []byte("a1b2c3"), MIMEType: "audio/wav"}, + mcp.NewAudioContent([]byte("a1b2c3"), "audio/wav"), + `{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}`, }, { - mcp.ResourceContent{ - Resource: mcp.TextResourceContents{ - URI: "file://foo", - MIMEType: "text", - Text: "abc", - }, - }, - mcp.WireContent{ - Type: "resource", - Resource: &mcp.WireResource{ - URI: "file://foo", - MIMEType: "text", - Text: "abc", - }, - }, + mcp.NewResourceContent( + mcp.NewTextResourceContents("file://foo", "text", "abc"), + ), + `{"type":"resource","resource":{"uri":"file://foo","mimeType":"text","text":"abc"}}`, }, { - mcp.ResourceContent{ - Resource: mcp.BlobResourceContents{ - URI: "file://foo", - MIMEType: "text", - Blob: []byte("a1b2c3"), - }, - }, - mcp.WireContent{ - Type: "resource", - Resource: &mcp.WireResource{ - URI: "file://foo", - MIMEType: "text", - Blob: []byte("a1b2c3"), - }, - }, + mcp.NewResourceContent( + mcp.NewBlobResourceContents("file://foo", "image/png", []byte("a1b2c3")), + ), + `{"type":"resource","resource":{"uri":"file://foo","mimeType":"image/png","blob":"YTFiMmMz"}}`, }, } for _, test := range tests { - got := test.in.ToWire() - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("ToWire mismatch (-want +got):\n%s", diff) + got, err := 
json.Marshal(test.in) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(test.want, string(got)); diff != "" { + t.Errorf("json.Marshal(%v) mismatch (-want +got):\n%s", test.in, diff) + } + var out mcp.Content + if err := json.Unmarshal(got, &out); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(test.in, out); diff != "" { + t.Errorf("json.Unmarshal(%q) mismatch (-want +got):\n%s", string(got), diff) } } } diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index aeba7f82158..542ebc9b8f0 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -123,6 +123,7 @@ type Stream interface { Close() error } ``` + Methods accept a Go `Context` and return an `error`, as is idiomatic for APIs that do I/O. @@ -350,17 +351,6 @@ type Content struct { Data []byte `json:"data,omitempty"` Resource *Resource `json:"resource,omitempty"` } - -// Resource is the wire format for embedded resources. -// -// The URI field describes the resource location. At most one of Text and Blob -// is non-zero. -type Resource struct { - URI string `json:"uri,"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text"` - Blob []byte `json:"blob"` -} ``` **Differences from mcp-go**: these types are largely similar, but our type @@ -445,6 +435,7 @@ content, err := session.CallTool(ctx, &CallToolParams{ ... return session.Close() ``` + A server that can handle that client call would look like this: ```go @@ -465,8 +456,6 @@ session until the client disconnects: func (*Server) Run(context.Context, Transport) ``` - - **Differences from mcp-go**: the Server APIs are very similar to mcp-go, though the association between servers and transports is different. In mcp-go, a single server is bound to what we would call an `SSEHTTPHandler`, @@ -493,9 +482,11 @@ documentation. 
As we saw above, the `ClientSession` method for the specification's `CallTool` RPC takes a context and a params pointer as arguments, and returns a result pointer and error: + ```go func (*ClientSession) CallTool(context.Context, *CallToolParams) (*CallToolResult, error) ``` + Our SDK has a method for every RPC in the spec, and their signatures all share this form. To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." @@ -867,12 +858,14 @@ handler to a Go function using reflection to derive its arguments. We provide To add a resource or resource template to a server, users call the `AddResource` and `AddResourceTemplate` methods, passing the resource or template and a function for reading it: + ```go type ReadResourceHandler func(context.Context, *ServerSession, *Resource, *ReadResourceParams) (*ReadResourceResult, error) func (*Server) AddResource(*Resource, ReadResourceHandler) func (*Server) AddResourceTemplate(*ResourceTemplate, ReadResourceHandler) ``` + The `Resource` is passed to the reader function even though it is redundant (the function could have closed over it) so a single handler can support multiple resources. If the incoming resource matches a template, a `Resource` argument is constructed @@ -880,15 +873,18 @@ from the fields in the `ResourceTemplate`. The `ServerSession` argument is there so the reader can observe the client's roots. To read files from the local filesystem, we recommend using `FileReadResourceHandler` to construct a handler: + ```go // FileReadResourceHandler returns a ReadResourceHandler that reads paths using dir as a root directory. // It protects against path traversal attacks. // It will not read any file that is not in the root set of the client requesting the resource. 
func (*Server) FileReadResourceHandler(dir string) ReadResourceHandler ``` + It guards against [path traversal attacks](https://go.dev/blog/osroot) and observes the client's roots. Here is an example: + ```go // Safely read "/public/puppies.txt". s.AddResource( @@ -897,14 +893,17 @@ s.AddResource( ``` There are also server methods to remove resources and resource templates. + ```go func (*Server) RemoveResources(uris ...string) func (*Server) RemoveResourceTemplates(names ...string) ``` + Resource templates don't have unique identifiers, so removing a name will remove all resource templates with that name. Servers support all of the resource-related spec methods: + - `ListResources` and `ListResourceTemplates` for listings. - `ReadResource` to get the contents of a resource. - `Subscribe` and `Unsubscribe` to manage subscriptions on resources. @@ -916,6 +915,7 @@ then returns the result of calling the associated reader function. #### Subscriptions ClientSessions can manage change notifications on particular resources: + ```go func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error @@ -929,6 +929,7 @@ user doesn't have to. If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. + ```go type ServerOptions struct { ... @@ -940,9 +941,11 @@ type ServerOptions struct { ``` User code should call `ResourceUpdated` when a subscribed resource changes. + ```go func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error ``` + The server routes these notifications to the server sessions that subscribed to the resource. ### ListChanged notifications @@ -1007,6 +1010,7 @@ follows: level, a handler would call `session.Log(ctx, mcp.LevelNotice, "message")`. 
A client that wishes to receive log messages must provide a handler: + ```go type ClientOptions struct { ... diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index 9cfba154d37..d58067217d0 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -22,7 +22,7 @@ type HiParams struct { func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]mcp.Content, error) { return []mcp.Content{ - mcp.TextContent{Text: "Hi " + params.Name}, + mcp.NewTextContent("Hi " + params.Name), }, nil } @@ -30,7 +30,7 @@ func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ( return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []mcp.PromptMessage{ - {Role: "user", Content: mcp.TextContent{Text: "Say hi to " + params.Name}.ToWire()}, + {Role: "user", Content: mcp.NewTextContent("Say hi to " + params.Name)}, }, }, nil } diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index fc590f7e0eb..4b6d104e83d 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -21,7 +21,7 @@ type SayHiParams struct { func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { return []mcp.Content{ - mcp.TextContent{Text: "Hi " + params.Name}, + mcp.NewTextContent("Hi " + params.Name), }, nil } diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index d764aa8346a..8d184c859df 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -107,7 +107,7 @@ var declarations = config{ Fields: config{"Params": {Name: "ReadResourceParams"}}, }, "ReadResourceResult": { - Fields: config{"Contents": {Substitute: "*WireResource"}}, + Fields: config{"Contents": {Substitute: "*ResourceContents"}}, }, "Resource": {}, "Role": {}, @@ -288,7 +288,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma if 
slices.ContainsFunc(def.AnyOf, func(s *jsonschema.Schema) bool { return s.Ref == "#/definitions/TextContent" }) { - fmt.Fprintf(w, "WireContent") + fmt.Fprintf(w, "Content") } else { // E.g. union types. fmt.Fprintf(w, "json.RawMessage") diff --git a/internal/mcp/internal/util/util.go b/internal/mcp/internal/util/util.go index c62a6f7e0af..cdc6038ede8 100644 --- a/internal/mcp/internal/util/util.go +++ b/internal/mcp/internal/util/util.go @@ -10,15 +10,6 @@ import ( "slices" ) -// Apply returns a new slice resulting from applying f to each element of x. -func Apply[S ~[]E, E, F any](x S, f func(E) F) []F { - y := make([]F, len(x)) - for i, e := range x { - y[i] = f(e) - } - return y -} - // Helpers below are copied from gopls' moremaps package. // sorted returns an iterator over the entries of m in key order. diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 65ef7222035..3377b689b5a 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -28,7 +28,7 @@ func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]Content, er if err := cc.Ping(ctx); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } - return []Content{TextContent{Text: "hi " + v.Name}}, nil + return []Content{NewTextContent("hi " + v.Name)}, nil } func TestEndToEnd(t *testing.T) { @@ -53,7 +53,7 @@ func TestEndToEnd(t *testing.T) { return &GetPromptResult{ Description: "Code review prompt", Messages: []PromptMessage{ - {Role: "user", Content: TextContent{Text: "Please review the following code: " + params.Code}.ToWire()}, + {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, }, }, nil }), @@ -116,7 +116,7 @@ func TestEndToEnd(t *testing.T) { wantReview := &GetPromptResult{ Description: "Code review prompt", Messages: []PromptMessage{{ - Content: TextContent{Text: "Please review the following code: 1+1"}.ToWire(), + Content: NewTextContent("Please review the following code: 1+1"), Role: "user", }}, } @@ 
-165,7 +165,7 @@ func TestEndToEnd(t *testing.T) { t.Fatal(err) } wantHi := &CallToolResult{ - Content: []WireContent{{Type: "text", Text: "hi user"}}, + Content: []Content{{Type: "text", Text: "hi user"}}, } if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) @@ -179,7 +179,7 @@ func TestEndToEnd(t *testing.T) { } wantFail := &CallToolResult{ IsError: true, - Content: []WireContent{{Type: "text", Text: failure.Error()}}, + Content: []Content{{Type: "text", Text: failure.Error()}}, } if diff := cmp.Diff(wantFail, gotFail); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) @@ -201,7 +201,7 @@ func TestEndToEnd(t *testing.T) { readHandler := func(_ context.Context, r Resource, _ *ReadResourceParams) (*ReadResourceResult, error) { if r.URI == "file:///file1.txt" { return &ReadResourceResult{ - Contents: &WireResource{ + Contents: &ResourceContents{ Text: "file contents", }, }, nil diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 9c7bde61dee..73ffa75819c 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -47,7 +47,7 @@ type CallToolResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content []WireContent `json:"content"` + Content []Content `json:"content"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -195,8 +195,8 @@ type PromptArgument struct { // This is similar to `SamplingMessage`, but also supports the embedding of // resources from the MCP server. 
type PromptMessage struct { - Content WireContent `json:"content"` - Role Role `json:"role"` + Content Content `json:"content"` + Role Role `json:"role"` } type ReadResourceParams struct { @@ -210,7 +210,7 @@ type ReadResourceResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Contents *WireResource `json:"contents"` + Contents *ResourceContents `json:"contents"` } // A known resource that the server is capable of reading. diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 9cdc1f2ad9f..5af54e999c1 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -18,7 +18,7 @@ type SayHiParams struct { func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { return []mcp.Content{ - mcp.TextContent{Text: "Hi " + params.Name}, + mcp.NewTextContent("Hi " + params.Name), }, nil } diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 028084faf66..5464356d249 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -20,7 +20,7 @@ type AddParams struct { func Add(ctx context.Context, cc *mcp.ServerConnection, params *AddParams) ([]mcp.Content, error) { return []mcp.Content{ - mcp.TextContent{Text: fmt.Sprintf("%d", params.X+params.Y)}, + mcp.NewTextContent(fmt.Sprintf("%d", params.X+params.Y)), }, nil } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index de7580ea847..10f04268a33 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -48,7 +48,7 @@ func TestSSEServer(t *testing.T) { t.Fatal(err) } wantHi := &CallToolResult{ - Content: []WireContent{{Type: "text", Text: "hi user"}}, + Content: []Content{{Type: "text", Text: "hi user"}}, } if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 
'greet' mismatch (-want +got):\n%s", diff) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index c2bab6eccd6..8c89f915ed9 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -9,7 +9,6 @@ import ( "encoding/json" "slices" - "golang.org/x/tools/internal/mcp/internal/util" "golang.org/x/tools/internal/mcp/jsonschema" ) @@ -56,12 +55,12 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * // rather than returned as jsonrpc2 server errors. if err != nil { return &CallToolResult{ - Content: []WireContent{TextContent{Text: err.Error()}.ToWire()}, + Content: []Content{NewTextContent(err.Error())}, IsError: true, }, nil } res := &CallToolResult{ - Content: util.Apply(content, Content.ToWire), + Content: content, } return res, nil } diff --git a/internal/mcp/util.go b/internal/mcp/util.go index 64d4d4851d0..15b3e63d874 100644 --- a/internal/mcp/util.go +++ b/internal/mcp/util.go @@ -14,11 +14,6 @@ func assert(cond bool, msg string) { } } -func is[T any](v any) bool { - _, ok := v.(T) - return ok -} - // Copied from crypto/rand. // TODO: once 1.24 is assured, just use crypto/rand. const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" From a5938faf76a884ef34cd1e59fa49c10f333bcca1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 18:06:01 -0400 Subject: [PATCH 057/196] internal/mcp: remove omitzero To be compatible with Go 1.23, we cannot use the encoding/json "omitzero" option. Replace with a custom MarshalJSON method. 
Change-Id: Ied615ea04fb20e2459f9ed20ef03c81edbc74961 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672417 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/mcp/content.go | 33 ++++++++++++++++++++++++--- internal/mcp/content_test.go | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 8e4f282b74b..379466da999 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -6,6 +6,7 @@ package mcp import ( "encoding/json" + "errors" "fmt" ) @@ -23,7 +24,7 @@ type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` MIMEType string `json:"mimeType,omitempty"` - Data []byte `json:"data,omitzero"` + Data []byte `json:"data,omitempty"` Resource *ResourceContents `json:"resource,omitempty"` Annotations *Annotations `json:"annotations,omitempty"` } @@ -68,8 +69,34 @@ func NewResourceContent(resource ResourceContents) Content { type ResourceContents struct { URI string `json:"uri,"` MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text,omitempty"` - Blob []byte `json:"blob,omitzero"` + Text string `json:"text"` + Blob []byte `json:"blob,omitempty"` +} + +func (r ResourceContents) MarshalJSON() ([]byte, error) { + if r.URI == "" { + return nil, errors.New("ResourceContents missing URI") + } + if r.Blob == nil { + // Text. Marshal normally. + type wireResourceContents ResourceContents + return json.Marshal((wireResourceContents)(r)) + } + // Blob. + if r.Text != "" { + return nil, errors.New("ResourceContents has non-zero Text and Blob fields") + } + // r.Blob may be the empty slice, so marshal with an alternative definition. 
+ br := struct { + URI string `json:"uri,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Blob []byte `json:"blob"` + }{ + URI: r.URI, + MIMEType: r.MIMEType, + Blob: r.Blob, + } + return json.Marshal(br) } func NewTextResourceContents(uri, mimeType, text string) ResourceContents { diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index 18f34a41a0d..9be3fbe940e 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -57,3 +57,46 @@ func TestContent(t *testing.T) { } } } + +func TestResourceContents(t *testing.T) { + for _, tt := range []struct { + rc mcp.ResourceContents + want string // marshaled JSON + }{ + { + mcp.ResourceContents{URI: "u", Text: "t"}, + `{"uri":"u","text":"t"}`, + }, + { + mcp.ResourceContents{URI: "u", MIMEType: "m", Text: "t"}, + `{"uri":"u","mimeType":"m","text":"t"}`, + }, + { + mcp.ResourceContents{URI: "u", Text: "", Blob: nil}, + `{"uri":"u","text":""}`, + }, + { + mcp.ResourceContents{URI: "u", Blob: []byte{}}, + `{"uri":"u","blob":""}`, + }, + { + mcp.ResourceContents{URI: "u", Blob: []byte{1}}, + `{"uri":"u","blob":"AQ=="}`, + }, + } { + data, err := json.Marshal(tt.rc) + if err != nil { + t.Fatal(err) + } + if got := string(data); got != tt.want { + t.Errorf("%#v:\ngot %s\nwant %s", tt.rc, got, tt.want) + } + var urc mcp.ResourceContents + if err := json.Unmarshal(data, &urc); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tt.rc, urc); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } + } +} From 3ab3cc4a55cdb5d98a5e59329bf2e1c5a4726257 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 18:17:05 -0400 Subject: [PATCH 058/196] internal/mcp: document content constructors Also, make sure that NewBlobResourceContents creates a Blob. 
Change-Id: I2f0fe2ac6410d6b155deb2c8643a40bc73e25cd4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672418 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/content.go | 58 ++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 379466da999..87098d193e3 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -10,16 +10,16 @@ import ( "fmt" ) -// The []byte fields below are marked omitzero, not omitempty: -// we want to marshal an empty byte slice. -// TODO(jba): figure out how to fix this for 1.23. - // Content is the wire format for content. // It represents the protocol types TextContent, ImageContent, AudioContent // and EmbeddedResource. -// The Type field distinguishes them. In the protocol, each type has a constant -// value for the field. -// At most one of Text, Data, and Resource is non-zero. +// Use [NewTextContent], [NewImageContent], [NewAudioContent] or [NewResourceContent] +// to create one. +// +// The Type field must be one of "text", "image", "audio" or "resource". The +// constructors above populate this field appropriately. +// Although at most one of Text, Data, and Resource should be non-zero, consumers of Content +// use the Type field to determine which value to use; values in the other fields are ignored. type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` @@ -44,42 +44,47 @@ func (c *Content) UnmarshalJSON(data []byte) error { return nil } +// NewTextContent creates a [Content] with text. func NewTextContent(text string) Content { return Content{Type: "text", Text: text} } +// NewImageContent creates a [Content] with image data. func NewImageContent(data []byte, mimeType string) Content { return Content{Type: "image", Data: data, MIMEType: mimeType} } +// NewAudioContent creates a [Content] with audio data. 
func NewAudioContent(data []byte, mimeType string) Content { return Content{Type: "audio", Data: data, MIMEType: mimeType} } -func NewResourceContent(resource ResourceContents) Content { - return Content{Type: "resource", Resource: &resource} +// NewResourceContent creates a [Content] with an embedded resource. +func NewResourceContent(resource *ResourceContents) Content { + return Content{Type: "resource", Resource: resource} } -// A ResourceContents is either a TextResourceContents or a BlobResourceContents. +// ResourceContents represents the union of the spec's {Text,Blob}ResourceContents types. // See https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts#L524-L551 // for the inheritance structure. -// If Blob is nil, this is a TextResourceContents; otherwise it's a BlobResourceContents. -// -// The URI field describes the resource location. + +// A ResourceContents is either a TextResourceContents or a BlobResourceContents. +// Use [NewTextResourceContents] or [NextBlobResourceContents] to create one. type ResourceContents struct { - URI string `json:"uri,"` + URI string `json:"uri"` // resource location; must not be empty MIMEType string `json:"mimeType,omitempty"` Text string `json:"text"` - Blob []byte `json:"blob,omitempty"` + Blob []byte `json:"blob,omitempty"` // if nil, then text; else blob } func (r ResourceContents) MarshalJSON() ([]byte, error) { + // If we could assume Go 1.24, we could use omitzero for Blob and avoid this method. if r.URI == "" { return nil, errors.New("ResourceContents missing URI") } if r.Blob == nil { // Text. Marshal normally. - type wireResourceContents ResourceContents + type wireResourceContents ResourceContents // (lacks MarshalJSON method) return json.Marshal((wireResourceContents)(r)) } // Blob. 
@@ -99,8 +104,9 @@ func (r ResourceContents) MarshalJSON() ([]byte, error) { return json.Marshal(br) } -func NewTextResourceContents(uri, mimeType, text string) ResourceContents { - return ResourceContents{ +// NewTextResourceContents returns a [ResourceContents] containing text. +func NewTextResourceContents(uri, mimeType, text string) *ResourceContents { + return &ResourceContents{ URI: uri, MIMEType: mimeType, Text: text, @@ -108,15 +114,13 @@ func NewTextResourceContents(uri, mimeType, text string) ResourceContents { } } -// A BlobResourceContents is the contents of a blob resource. -type BlobResourceContents struct { - URI string - MIMEType string - Blob []byte -} - -func NewBlobResourceContents(uri, mimeType string, blob []byte) ResourceContents { - return ResourceContents{ +// NewTextResourceContents returns a [ResourceContents] containing a byte slice. +func NewBlobResourceContents(uri, mimeType string, blob []byte) *ResourceContents { + // The only way to distinguish text from blob is a non-nil Blob field. + if blob == nil { + blob = []byte{} + } + return &ResourceContents{ URI: uri, MIMEType: mimeType, Blob: blob, From e31df7749f71480a23b8714b3da770e32ff4f0e5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 19:10:34 -0400 Subject: [PATCH 059/196] internal/mcp: pointerize all the things Use pointers for the protocol's struct types, to future-proof against them getting large. 
Change-Id: I55b204080a41c8140539697ba264e502e912e613 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672575 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/internal/mcp/mcp.go | 6 +-- internal/mcp/client.go | 12 +++--- internal/mcp/cmd_test.go | 2 +- internal/mcp/content.go | 16 ++++---- internal/mcp/content_test.go | 4 +- internal/mcp/examples/hello/main.go | 6 +-- internal/mcp/examples/sse/main.go | 4 +- internal/mcp/generate.go | 63 +++++++++++++++++++++-------- internal/mcp/jsonschema/schema.go | 8 ++++ internal/mcp/mcp_test.go | 32 +++++++-------- internal/mcp/prompt.go | 12 +++--- internal/mcp/prompt_test.go | 6 +-- internal/mcp/protocol.go | 42 +++++++++---------- internal/mcp/server.go | 8 ++-- internal/mcp/server_example_test.go | 4 +- internal/mcp/sse_example_test.go | 4 +- internal/mcp/sse_test.go | 2 +- internal/mcp/tool.go | 8 ++-- internal/mcp/tool_test.go | 2 +- 19 files changed, 140 insertions(+), 101 deletions(-) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index b21a700d598..cd8236e53c7 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -140,9 +140,9 @@ type HelloParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { - return func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]mcp.Content, error) { - return []mcp.Content{ +func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]*mcp.Content, error) { + return func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]*mcp.Content, error) { + return []*mcp.Content{ mcp.NewTextContent("Hi " + request.Name + ", this is lsp session " + session.ID()), }, nil } diff --git a/internal/mcp/client.go b/internal/mcp/client.go index f538e16fddb..4d7c35c7b05 100644 --- 
a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -24,7 +24,7 @@ type Client struct { opts ClientOptions mu sync.Mutex conn *jsonrpc2.Connection - roots *featureSet[Root] + roots *featureSet[*Root] initializeResult *initializeResult } @@ -38,7 +38,7 @@ func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { name: name, version: version, transport: t, - roots: newFeatureSet(func(r Root) string { return r.URI }), + roots: newFeatureSet(func(r *Root) string { return r.URI }), } if opts != nil { c.opts = *opts @@ -84,7 +84,7 @@ func (c *Client) Start(ctx context.Context) (err error) { return err } params := &initializeParams{ - ClientInfo: implementation{Name: c.name, Version: c.version}, + ClientInfo: &implementation{Name: c.name, Version: c.version}, } if err := call(ctx, c.conn, "initialize", params, &c.initializeResult); err != nil { return err @@ -112,7 +112,7 @@ func (c *Client) Wait() error { // replacing any with the same URIs, // and notifies any connected servers. // TODO: notification -func (c *Client) AddRoots(roots ...Root) { +func (c *Client) AddRoots(roots ...*Root) { c.mu.Lock() defer c.mu.Unlock() c.roots.add(roots...) @@ -159,7 +159,7 @@ func (c *Client) Ping(ctx context.Context) error { } // ListPrompts lists prompts that are currently available on the server. -func (c *Client) ListPrompts(ctx context.Context) ([]Prompt, error) { +func (c *Client) ListPrompts(ctx context.Context) ([]*Prompt, error) { var ( params = &ListPromptsParams{} result ListPromptsResult @@ -186,7 +186,7 @@ func (c *Client) GetPrompt(ctx context.Context, name string, args map[string]str } // ListTools lists tools that are currently available on the server. 
-func (c *Client) ListTools(ctx context.Context) ([]Tool, error) { +func (c *Client) ListTools(ctx context.Context) ([]*Tool, error) { var ( params = &ListToolsParams{} result ListToolsResult diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index f4ce863785a..764b77bd11f 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -57,7 +57,7 @@ func TestCmdTransport(t *testing.T) { log.Fatal(err) } want := &mcp.CallToolResult{ - Content: []mcp.Content{{Type: "text", Text: "Hi user"}}, + Content: []*mcp.Content{{Type: "text", Text: "Hi user"}}, } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 87098d193e3..94f5cd18f4a 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -45,23 +45,23 @@ func (c *Content) UnmarshalJSON(data []byte) error { } // NewTextContent creates a [Content] with text. -func NewTextContent(text string) Content { - return Content{Type: "text", Text: text} +func NewTextContent(text string) *Content { + return &Content{Type: "text", Text: text} } // NewImageContent creates a [Content] with image data. -func NewImageContent(data []byte, mimeType string) Content { - return Content{Type: "image", Data: data, MIMEType: mimeType} +func NewImageContent(data []byte, mimeType string) *Content { + return &Content{Type: "image", Data: data, MIMEType: mimeType} } // NewAudioContent creates a [Content] with audio data. -func NewAudioContent(data []byte, mimeType string) Content { - return Content{Type: "audio", Data: data, MIMEType: mimeType} +func NewAudioContent(data []byte, mimeType string) *Content { + return &Content{Type: "audio", Data: data, MIMEType: mimeType} } // NewResourceContent creates a [Content] with an embedded resource. 
-func NewResourceContent(resource *ResourceContents) Content { - return Content{Type: "resource", Resource: resource} +func NewResourceContent(resource *ResourceContents) *Content { + return &Content{Type: "resource", Resource: resource} } // ResourceContents represents the union of the spec's {Text,Blob}ResourceContents types. diff --git a/internal/mcp/content_test.go b/internal/mcp/content_test.go index 9be3fbe940e..59f41e0bf85 100644 --- a/internal/mcp/content_test.go +++ b/internal/mcp/content_test.go @@ -14,7 +14,7 @@ import ( func TestContent(t *testing.T) { tests := []struct { - in mcp.Content + in *mcp.Content want string // json serialization }{ {mcp.NewTextContent("hello"), `{"type":"text","text":"hello"}`}, @@ -48,7 +48,7 @@ func TestContent(t *testing.T) { if diff := cmp.Diff(test.want, string(got)); diff != "" { t.Errorf("json.Marshal(%v) mismatch (-want +got):\n%s", test.in, diff) } - var out mcp.Content + var out *mcp.Content if err := json.Unmarshal(got, &out); err != nil { t.Fatal(err) } diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index d58067217d0..b71f86d242a 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -20,8 +20,8 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]mcp.Content, error) { - return []mcp.Content{ +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]*mcp.Content, error) { + return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil } @@ -29,7 +29,7 @@ func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]m func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", - Messages: []mcp.PromptMessage{ + Messages: []*mcp.PromptMessage{ {Role: "user", Content: mcp.NewTextContent("Say hi to " 
+ params.Name)}, }, }, nil diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index 4b6d104e83d..ba793936621 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -19,8 +19,8 @@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { - return []mcp.Content{ +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]*mcp.Content, error) { + return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil } diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 8d184c859df..84c79f3bda1 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -139,6 +139,10 @@ func main() { if err := json.Unmarshal(data, &schema); err != nil { log.Fatal(err) } + // Resolve the schema so we have the referents of all the Refs. + if _, err := schema.Resolve("", nil); err != nil { + log.Fatal(err) + } // Collect named types. Since we may create new type definitions while // writing types, we collect definitions and concatenate them later. This @@ -253,7 +257,8 @@ func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, nam // be added during writeType, if they are extracted from inner fields. func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named map[string]*bytes.Buffer) error { // Use type names for Named types. - if name := strings.TrimPrefix(def.Ref, "#/definitions/"); name != "" { + name, resolved := deref(def) + if name != "" { // TODO: this check is not quite right: we should really panic if the // definition is missing, *but only if w is not io.Discard*. That's not a // great API: see if we can do something more explicit than io.Discard. 
@@ -266,6 +271,9 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma } else if cfg.Name != "" { name = cfg.Name } + if isStruct(resolved) { + w.Write([]byte{'*'}) + } } w.Write([]byte(name)) return nil @@ -288,7 +296,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma if slices.ContainsFunc(def.AnyOf, func(s *jsonschema.Schema) bool { return s.Ref == "#/definitions/TextContent" }) { - fmt.Fprintf(w, "Content") + fmt.Fprintf(w, "*Content") } else { // E.g. union types. fmt.Fprintf(w, "json.RawMessage") @@ -322,14 +330,15 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma required := slices.Contains(def.Required, name) - // If the field is not required, and is a struct type, indirect with a + // If the field is a struct type, indirect with a // pointer so that it can be empty as defined by encoding/json. - // - // TODO: use omitzero when available. - needPointer := !required && - (strings.HasPrefix(fieldDef.Ref, "#/definitions/") || - fieldDef.Type == "object" && !canHaveAdditionalProperties(fieldDef)) - + // This also future-proofs against the struct getting large. + fieldTypeSchema := fieldDef + // If the schema is a reference, dereference it. 
+ if _, rs := deref(fieldDef); rs != nil { + fieldTypeSchema = rs + } + needPointer := isStruct(fieldTypeSchema) if config != nil && config.Fields[export] != nil { r := config.Fields[export] if r.Substitute != "" { @@ -348,13 +357,8 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma } fmt.Fprintf(w, typename) } - } else { - if needPointer { - fmt.Fprintf(w, "*") - } - if err := writeType(w, nil, fieldDef, named); err != nil { - return fmt.Errorf("failed to write type for field %s: %v", export, err) - } + } else if err := writeType(w, nil, fieldDef, named); err != nil { + return fmt.Errorf("failed to write type for field %s: %v", export, err) } fmt.Fprintf(w, " `json:\"%s", name) if !required { @@ -452,6 +456,33 @@ func exportName(s string) string { return s } +// deref dereferences s.Ref. +// If s.Ref refers to a schema in the Definitions section, deref +// returns the definition name and the associated schema. +// Otherwise, deref returns "", nil. +func deref(s *jsonschema.Schema) (name string, _ *jsonschema.Schema) { + name, ok := strings.CutPrefix(s.Ref, "#/definitions/") + if !ok { + return "", nil + } + return name, s.ResolvedRef() +} + +// isStruct reports whether s should be translated to a struct. +func isStruct(s *jsonschema.Schema) bool { + return s.Type == "object" && s.Properties != nil && !canHaveAdditionalProperties(s) +} + +// schemaJSON returns the JSON for s. +// For debugging. +func schemaJSON(s *jsonschema.Schema) string { + data, err := json.Marshal(s) + if err != nil { + return fmt.Sprintf("", err) + } + return string(data) +} + // Map from initialism to the regexp that matches it. 
var initialisms = map[string]*regexp.Regexp{ "Id": nil, diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index c032661baf1..f1e16a5decc 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -138,6 +138,14 @@ func (s *Schema) String() string { return "" } +// ResolvedRef returns the Schema to which this schema's $ref keyword +// refers, or nil if it doesn't have a $ref. +// It returns nil if this schema has not been resolved, meaning that +// [Schema.Resolve] was not called on it or one of its ancestors. +func (s *Schema) ResolvedRef() *Schema { + return s.resolvedRef +} + // json returns the schema in json format. func (s *Schema) json() string { data, err := json.MarshalIndent(s, "", " ") diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 3377b689b5a..4d8a20c9d75 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -24,11 +24,11 @@ type hiParams struct { Name string } -func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]Content, error) { +func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]*Content, error) { if err := cc.Ping(ctx); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } - return []Content{NewTextContent("hi " + v.Name)}, nil + return []*Content{NewTextContent("hi " + v.Name)}, nil } func TestEndToEnd(t *testing.T) { @@ -43,7 +43,7 @@ func TestEndToEnd(t *testing.T) { // The 'fail' tool returns this error.
failure := errors.New("mcp failure") s.AddTools( - NewTool("fail", "just fail", func(context.Context, *ServerConnection, struct{}) ([]Content, error) { + NewTool("fail", "just fail", func(context.Context, *ServerConnection, struct{}) ([]*Content, error) { return nil, failure }), ) @@ -52,7 +52,7 @@ func TestEndToEnd(t *testing.T) { NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerConnection, params struct{ Code string }) (*GetPromptResult, error) { return &GetPromptResult{ Description: "Code review prompt", - Messages: []PromptMessage{ + Messages: []*PromptMessage{ {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, }, }, nil @@ -82,7 +82,7 @@ func TestEndToEnd(t *testing.T) { }() c := NewClient("testClient", "v1.0.0", ct, nil) - c.AddRoots(Root{URI: "file:///root"}) + c.AddRoots(&Root{URI: "file:///root"}) // Connect the client. if err := c.Start(ctx); err != nil { @@ -97,11 +97,11 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Errorf("prompts/list failed: %v", err) } - wantPrompts := []Prompt{ + wantPrompts := []*Prompt{ { Name: "code_review", Description: "do a code review", - Arguments: []PromptArgument{{Name: "Code", Required: true}}, + Arguments: []*PromptArgument{{Name: "Code", Required: true}}, }, {Name: "fail"}, } @@ -115,7 +115,7 @@ func TestEndToEnd(t *testing.T) { } wantReview := &GetPromptResult{ Description: "Code review prompt", - Messages: []PromptMessage{{ + Messages: []*PromptMessage{{ Content: NewTextContent("Please review the following code: 1+1"), Role: "user", }}, @@ -134,7 +134,7 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Errorf("tools/list failed: %v", err) } - wantTools := []Tool{ + wantTools := []*Tool{ { Name: "fail", Description: "just fail", @@ -165,7 +165,7 @@ func TestEndToEnd(t *testing.T) { t.Fatal(err) } wantHi := &CallToolResult{ - Content: []Content{{Type: "text", Text: "hi user"}}, + Content: []*Content{{Type: "text", Text: "hi user"}}, } 
if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) @@ -179,7 +179,7 @@ func TestEndToEnd(t *testing.T) { } wantFail := &CallToolResult{ IsError: true, - Content: []Content{{Type: "text", Text: failure.Error()}}, + Content: []*Content{{Type: "text", Text: failure.Error()}}, } if diff := cmp.Diff(wantFail, gotFail); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) @@ -187,18 +187,18 @@ func TestEndToEnd(t *testing.T) { }) t.Run("resources", func(t *testing.T) { - resource1 := Resource{ + resource1 := &Resource{ Name: "public", MIMEType: "text/plain", URI: "file:///file1.txt", } - resource2 := Resource{ + resource2 := &Resource{ Name: "public", // names are not unique IDs MIMEType: "text/plain", URI: "file:///nonexistent.txt", } - readHandler := func(_ context.Context, r Resource, _ *ReadResourceParams) (*ReadResourceResult, error) { + readHandler := func(_ context.Context, r *Resource, _ *ReadResourceParams) (*ReadResourceResult, error) { if r.URI == "file:///file1.txt" { return &ReadResourceResult{ Contents: &ResourceContents{ @@ -216,7 +216,7 @@ func TestEndToEnd(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff([]Resource{resource1, resource2}, lrres.Resources); diff != "" { + if diff := cmp.Diff([]*Resource{resource1, resource2}, lrres.Resources); diff != "" { t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) } @@ -370,7 +370,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, cc *ServerConnection, v struct{}) ([]Content, error) { + slowRequest := func(ctx context.Context, cc *ServerConnection, v struct{}) ([]*Content, error) { start <- struct{}{} select { case <-ctx.Done(): diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 3eb3bb53668..c813cee5413 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -20,7 +20,7 @@ 
type PromptHandler func(context.Context, *ServerConnection, map[string]string) ( // A Prompt is a prompt definition bound to a prompt handler. type ServerPrompt struct { - Definition Prompt + Definition *Prompt Handler PromptHandler } @@ -41,7 +41,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, panic(fmt.Sprintf("handler request type must be a struct")) } prompt := &ServerPrompt{ - Definition: Prompt{ + Definition: &Prompt{ Name: name, Description: description, }, @@ -54,7 +54,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if prop.Type != "string" { panic(fmt.Sprintf("handler type must consist only of string fields")) } - prompt.Definition.Arguments = append(prompt.Definition.Arguments, PromptArgument{ + prompt.Definition.Arguments = append(prompt.Definition.Arguments, &PromptArgument{ Name: name, Description: prop.Description, Required: required[name], @@ -95,13 +95,13 @@ func (s promptSetter) set(p *ServerPrompt) { s(p) } // Required and Description, and panics when encountering any other option. 
func Argument(name string, opts ...SchemaOption) PromptOption { return promptSetter(func(p *ServerPrompt) { - i := slices.IndexFunc(p.Definition.Arguments, func(arg PromptArgument) bool { + i := slices.IndexFunc(p.Definition.Arguments, func(arg *PromptArgument) bool { return arg.Name == name }) - var arg PromptArgument + var arg *PromptArgument if i < 0 { i = len(p.Definition.Arguments) - arg = PromptArgument{Name: name} + arg = &PromptArgument{Name: name} p.Definition.Arguments = append(p.Definition.Arguments, arg) } else { arg = p.Definition.Arguments[i] diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 29a1c5ac5a0..88220fa577b 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -20,7 +20,7 @@ func testPromptHandler[T any](context.Context, *mcp.ServerConnection, T) (*mcp.G func TestNewPrompt(t *testing.T) { tests := []struct { prompt *mcp.ServerPrompt - want []mcp.PromptArgument + want []*mcp.PromptArgument }{ { mcp.NewPrompt("empty", "", testPromptHandler[struct{}]), @@ -28,7 +28,7 @@ func TestNewPrompt(t *testing.T) { }, { mcp.NewPrompt("add_arg", "", testPromptHandler[struct{}], mcp.Argument("x")), - []mcp.PromptArgument{{Name: "x"}}, + []*mcp.PromptArgument{{Name: "x"}}, }, { mcp.NewPrompt("combo", "", testPromptHandler[struct { @@ -38,7 +38,7 @@ func TestNewPrompt(t *testing.T) { }], mcp.Argument("name", mcp.Description("the person's name")), mcp.Argument("State", mcp.Required(false))), - []mcp.PromptArgument{ + []*mcp.PromptArgument{ {Name: "State"}, {Name: "country"}, {Name: "name", Required: true, Description: "the person's name"}, diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 73ffa75819c..cbe868b941c 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -47,7 +47,7 @@ type CallToolResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. 
Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content []Content `json:"content"` + Content []*Content `json:"content"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -73,7 +73,7 @@ type ClientCapabilities struct { Experimental map[string]struct { } `json:"experimental,omitempty"` // Present if the client supports listing roots. - Roots *struct { + Roots struct { // Whether the client supports notifications for changes to the roots list. ListChanged bool `json:"listChanged,omitempty"` } `json:"roots,omitempty"` @@ -95,8 +95,8 @@ type GetPromptResult struct { // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` // An optional description for the prompt. - Description string `json:"description,omitempty"` - Messages []PromptMessage `json:"messages"` + Description string `json:"description,omitempty"` + Messages []*PromptMessage `json:"messages"` } type ListPromptsParams struct { @@ -112,8 +112,8 @@ type ListPromptsResult struct { Meta map[string]json.RawMessage `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. - NextCursor string `json:"nextCursor,omitempty"` - Prompts []Prompt `json:"prompts"` + NextCursor string `json:"nextCursor,omitempty"` + Prompts []*Prompt `json:"prompts"` } type ListResourcesParams struct { @@ -129,17 +129,17 @@ type ListResourcesResult struct { Meta map[string]json.RawMessage `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. 
- NextCursor string `json:"nextCursor,omitempty"` - Resources []Resource `json:"resources"` + NextCursor string `json:"nextCursor,omitempty"` + Resources []*Resource `json:"resources"` } type ListRootsParams struct { - Meta *struct { + Meta struct { // If specified, the caller is requesting out-of-band progress notifications for // this request (as represented by notifications/progress). The value of this // parameter is an opaque token that will be attached to any subsequent // notifications. The receiver is not obligated to provide these notifications. - ProgressToken *any `json:"progressToken,omitempty"` + ProgressToken any `json:"progressToken,omitempty"` } `json:"_meta,omitempty"` } @@ -150,7 +150,7 @@ type ListRootsResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Roots []Root `json:"roots"` + Roots []*Root `json:"roots"` } type ListToolsParams struct { @@ -166,14 +166,14 @@ type ListToolsResult struct { Meta map[string]json.RawMessage `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. - NextCursor string `json:"nextCursor,omitempty"` - Tools []Tool `json:"tools"` + NextCursor string `json:"nextCursor,omitempty"` + Tools []*Tool `json:"tools"` } // A prompt or prompt template that the server offers. type Prompt struct { // A list of arguments to use for templating the prompt. - Arguments []PromptArgument `json:"arguments,omitempty"` + Arguments []*PromptArgument `json:"arguments,omitempty"` // An optional description of what this prompt provides Description string `json:"description,omitempty"` // The name of the prompt or prompt template. 
@@ -195,8 +195,8 @@ type PromptArgument struct { // This is similar to `SamplingMessage`, but also supports the embedding of // resources from the MCP server. type PromptMessage struct { - Content Content `json:"content"` - Role Role `json:"role"` + Content *Content `json:"content"` + Role Role `json:"role"` } type ReadResourceParams struct { @@ -312,8 +312,8 @@ type implementation struct { } type initializeParams struct { - Capabilities ClientCapabilities `json:"capabilities"` - ClientInfo implementation `json:"clientInfo"` + Capabilities *ClientCapabilities `json:"capabilities"` + ClientInfo *implementation `json:"clientInfo"` // The latest version of the Model Context Protocol that the client supports. // The client MAY decide to support older versions as well. ProtocolVersion string `json:"protocolVersion"` @@ -325,7 +325,7 @@ type initializeResult struct { // This result property is reserved by the protocol to allow clients and servers // to attach additional metadata to their responses. Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Capabilities serverCapabilities `json:"capabilities"` + Capabilities *serverCapabilities `json:"capabilities"` // Instructions describing how to use the server and its features. // // This can be used by clients to improve the LLM's understanding of available @@ -335,8 +335,8 @@ type initializeResult struct { // The version of the Model Context Protocol that the server wants to use. This // may not match the version that the client requested. If the client cannot // support this version, it MUST disconnect. 
- ProtocolVersion string `json:"protocolVersion"` - ServerInfo implementation `json:"serverInfo"` + ProtocolVersion string `json:"protocolVersion"` + ServerInfo *implementation `json:"serverInfo"` } type initializedParams struct { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index f6f5cd1d43a..05ca2107d4f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -121,11 +121,11 @@ const codeResourceNotFound = -31002 // A ReadResourceHandler is a function that reads a resource. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ReadResourceHandler func(context.Context, Resource, *ReadResourceParams) (*ReadResourceResult, error) +type ReadResourceHandler func(context.Context, *Resource, *ReadResourceParams) (*ReadResourceResult, error) // A ServerResource associates a Resource with its handler. type ServerResource struct { - Resource Resource + Resource *Resource Handler ReadResourceHandler } @@ -381,7 +381,7 @@ func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, return &initializeResult{ // TODO(rfindley): support multiple protocol versions. 
ProtocolVersion: "2024-11-05", - Capabilities: serverCapabilities{ + Capabilities: &serverCapabilities{ Prompts: &promptCapabilities{ ListChanged: false, // not yet supported }, @@ -390,7 +390,7 @@ func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, }, }, Instructions: cc.server.opts.Instructions, - ServerInfo: implementation{ + ServerInfo: &implementation{ Name: cc.server.name, Version: cc.server.version, }, diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 5af54e999c1..e596fe69c34 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -16,8 +16,8 @@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]mcp.Content, error) { - return []mcp.Content{ +func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]*mcp.Content, error) { + return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil } diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 5464356d249..397a1f8d5c1 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -18,8 +18,8 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, cc *mcp.ServerConnection, params *AddParams) ([]mcp.Content, error) { - return []mcp.Content{ +func Add(ctx context.Context, cc *mcp.ServerConnection, params *AddParams) ([]*mcp.Content, error) { + return []*mcp.Content{ mcp.NewTextContent(fmt.Sprintf("%d", params.X+params.Y)), }, nil } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 10f04268a33..4f65339377c 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -48,7 +48,7 @@ func TestSSEServer(t *testing.T) { t.Fatal(err) } wantHi := &CallToolResult{ - Content: []Content{{Type: "text", Text: "hi user"}}, + Content: []*Content{{Type: "text", Text: "hi 
user"}}, } if diff := cmp.Diff(wantHi, gotHi); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 8c89f915ed9..4bfce459ebb 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -17,7 +17,7 @@ type ToolHandler func(context.Context, *ServerConnection, map[string]json.RawMes // A Tool is a tool definition that is bound to a tool handler. type ServerTool struct { - Definition Tool + Definition *Tool Handler ToolHandler } @@ -34,7 +34,7 @@ type ServerTool struct { // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]Content, error), opts ...ToolOption) *ServerTool { +func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]*Content, error), opts ...ToolOption) *ServerTool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) @@ -55,7 +55,7 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * // rather than returned as jsonrpc2 server errors. if err != nil { return &CallToolResult{ - Content: []Content{NewTextContent(err.Error())}, + Content: []*Content{NewTextContent(err.Error())}, IsError: true, }, nil } @@ -65,7 +65,7 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * return res, nil } t := &ServerTool{ - Definition: Tool{ + Definition: &Tool{ Name: name, Description: description, InputSchema: schema, diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index e6bf4c0eafb..a4202a221fe 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -15,7 +15,7 @@ import ( ) // testToolHandler is used for type inference in TestNewTool. 
-func testToolHandler[T any](context.Context, *mcp.ServerConnection, T) ([]mcp.Content, error) { +func testToolHandler[T any](context.Context, *mcp.ServerConnection, T) ([]*mcp.Content, error) { panic("not implemented") } From ccbd1d95a5e465ba28ed919e32c06665973066b5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 14 May 2025 05:50:50 -0400 Subject: [PATCH 060/196] internal/mcp: change function signatures match design doc To ensure backward compatibility, every method in the spec now takes a *Params struct and returns a *Result struct. CallTools is an exception: we'll change that in the next CL. Change-Id: I58835239c67582401b8c9c2d070702d270131c08 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672595 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 40 ++++++++-------------------------------- internal/mcp/generate.go | 14 +++++++++----- internal/mcp/mcp_test.go | 18 +++++++++--------- internal/mcp/protocol.go | 10 ++++++++++ internal/mcp/server.go | 2 +- internal/mcp/sse_test.go | 2 +- 6 files changed, 38 insertions(+), 48 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 4d7c35c7b05..bac4f3cadfa 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -154,47 +154,23 @@ func (c *Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) } // Ping makes an MCP "ping" request to the server. -func (c *Client) Ping(ctx context.Context) error { - return call(ctx, c.conn, "ping", nil, nil) +func (c *Client) Ping(ctx context.Context, params *PingParams) error { + return call(ctx, c.conn, "ping", params, nil) } // ListPrompts lists prompts that are currently available on the server. 
-func (c *Client) ListPrompts(ctx context.Context) ([]*Prompt, error) { - var ( - params = &ListPromptsParams{} - result ListPromptsResult - ) - if err := call(ctx, c.conn, "prompts/list", params, &result); err != nil { - return nil, err - } - return result.Prompts, nil +func (c *Client) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { + return standardCall[ListPromptsResult](ctx, c.conn, "prompts/list", params) } // GetPrompt gets a prompt from the server. -func (c *Client) GetPrompt(ctx context.Context, name string, args map[string]string) (*GetPromptResult, error) { - var ( - params = &GetPromptParams{ - Name: name, - Arguments: args, - } - result = &GetPromptResult{} - ) - if err := call(ctx, c.conn, "prompts/get", params, result); err != nil { - return nil, err - } - return result, nil +func (c *Client) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { + return standardCall[GetPromptResult](ctx, c.conn, "prompts/get", params) } // ListTools lists tools that are currently available on the server. -func (c *Client) ListTools(ctx context.Context) ([]*Tool, error) { - var ( - params = &ListToolsParams{} - result ListToolsResult - ) - if err := call(ctx, c.conn, "tools/list", params, &result); err != nil { - return nil, err - } - return result.Tools, nil +func (c *Client) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { + return standardCall[ListToolsResult](ctx, c.conn, "tools/list", params) } // CallTool calls the tool with the given name and arguments. 
diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 84c79f3bda1..342f1cc1122 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -97,11 +97,15 @@ var declarations = config{ Fields: config{"Params": {Name: "ListToolsParams"}}, }, "ListToolsResult": {}, - "Prompt": {}, - "PromptMessage": {}, - "PromptArgument": {}, - "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string - "RequestId": {Name: "-", Substitute: "any"}, // null|number|string + "PingRequest": { + Name: "-", + Fields: config{"Params": {Name: "PingParams"}}, + }, + "Prompt": {}, + "PromptMessage": {}, + "PromptArgument": {}, + "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string + "RequestId": {Name: "-", Substitute: "any"}, // null|number|string "ReadResourceRequest": { Name: "-", Fields: config{"Params": {Name: "ReadResourceParams"}}, diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 4d8a20c9d75..f62f0ae3c53 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -25,7 +25,7 @@ type hiParams struct { } func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]*Content, error) { - if err := cc.Ping(ctx); err != nil { + if err := cc.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } return []*Content{NewTextContent("hi " + v.Name)}, nil @@ -89,11 +89,11 @@ func TestEndToEnd(t *testing.T) { t.Fatal(err) } - if err := c.Ping(ctx); err != nil { + if err := c.Ping(ctx, nil); err != nil { t.Fatalf("ping failed: %v", err) } t.Run("prompts", func(t *testing.T) { - gotPrompts, err := c.ListPrompts(ctx) + res, err := c.ListPrompts(ctx, nil) if err != nil { t.Errorf("prompts/list failed: %v", err) } @@ -105,11 +105,11 @@ func TestEndToEnd(t *testing.T) { }, {Name: "fail"}, } - if diff := cmp.Diff(wantPrompts, gotPrompts); diff != "" { + if diff := cmp.Diff(wantPrompts, res.Prompts); diff != "" { t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) } - gotReview, err 
:= c.GetPrompt(ctx, "code_review", map[string]string{"Code": "1+1"}) + gotReview, err := c.GetPrompt(ctx, &GetPromptParams{Name: "code_review", Arguments: map[string]string{"Code": "1+1"}}) if err != nil { t.Fatal(err) } @@ -124,13 +124,13 @@ func TestEndToEnd(t *testing.T) { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } - if _, err := c.GetPrompt(ctx, "fail", map[string]string{}); err == nil || !strings.Contains(err.Error(), failure.Error()) { + if _, err := c.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), failure.Error()) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) } }) t.Run("tools", func(t *testing.T) { - gotTools, err := c.ListTools(ctx) + res, err := c.ListTools(ctx, nil) if err != nil { t.Errorf("tools/list failed: %v", err) } @@ -156,7 +156,7 @@ func TestEndToEnd(t *testing.T) { }, }, } - if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) } @@ -350,7 +350,7 @@ func TestBatching(t *testing.T) { errs := make(chan error, batchSize) for i := range batchSize { go func() { - _, err := c.ListTools(ctx) + _, err := c.ListTools(ctx, nil) errs <- err }() time.Sleep(2 * time.Millisecond) diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index cbe868b941c..95444a60edc 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -170,6 +170,16 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +type PingParams struct { + Meta struct { + // If specified, the caller is requesting out-of-band progress notifications for + // this request (as represented by notifications/progress). The value of this + // parameter is an opaque token that will be attached to any subsequent + // notifications. 
The receiver is not obligated to provide these notifications. + ProgressToken any `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` +} + // A prompt or prompt template that the server offers. type Prompt struct { // A list of arguments to use for templating the prompt. diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 05ca2107d4f..a0e1b210c43 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -303,7 +303,7 @@ type ServerConnection struct { } // Ping makes an MCP "ping" request to the client. -func (cc *ServerConnection) Ping(ctx context.Context) error { +func (cc *ServerConnection) Ping(ctx context.Context, _ *PingParams) error { return call(ctx, cc.conn, "ping", nil, nil) } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 4f65339377c..52c50fb42de 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -39,7 +39,7 @@ func TestSSEServer(t *testing.T) { if err := c.Start(ctx); err != nil { t.Fatal(err) } - if err := c.Ping(ctx); err != nil { + if err := c.Ping(ctx, nil); err != nil { t.Fatal(err) } cc := <-conns From cd3f34cfad582b6fd95687728e8a7d3427427978 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 14 May 2025 07:01:59 -0400 Subject: [PATCH 061/196] internal/mcp: change CallTool signature CallTool takes a tool name, arguments, and an options struct. Also, pass the params to ToolHandler. Update the design doc accordingly. 
Change-Id: I21b7995b607bea976ad708ebb6ac0716018f39cb Reviewed-on: https://go-review.googlesource.com/c/tools/+/672596 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 10 +++++++++- internal/mcp/cmd_test.go | 2 +- internal/mcp/design/design.md | 31 ++++++++++++++--------------- internal/mcp/mcp_test.go | 10 +++++----- internal/mcp/server.go | 2 +- internal/mcp/server_example_test.go | 2 +- internal/mcp/sse_example_test.go | 2 +- internal/mcp/sse_test.go | 2 +- internal/mcp/tool.go | 6 +++--- 9 files changed, 37 insertions(+), 30 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index bac4f3cadfa..6dc0ab3a1ff 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -174,7 +174,8 @@ func (c *Client) ListTools(ctx context.Context, params *ListToolsParams) (*ListT } // CallTool calls the tool with the given name and arguments. -func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) (_ *CallToolResult, err error) { +// Pass a [CallToolOptions] to provide additional request fields. +func (c *Client) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { defer func() { if err != nil { err = fmt.Errorf("calling tool %q: %w", name, err) @@ -196,6 +197,13 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) return standardCall[CallToolResult](ctx, c.conn, "tools/call", params) } +// NOTE: the following struct should consist of all fields of callToolParams except name and arguments. + +// CallToolOptions contains options to [Client.CallTool]. +type CallToolOptions struct { + ProgressToken any // string or int +} + // ListResources lists the resources that are currently available on the server.
func (c *Client) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { return standardCall[ListResourcesResult](ctx, c.conn, "resources/list", params) diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index 764b77bd11f..cf899178ba1 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -52,7 +52,7 @@ func TestCmdTransport(t *testing.T) { if err := client.Start(ctx); err != nil { log.Fatal(err) } - got, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}) + got, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 542ebc9b8f0..d82d3a3fec6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -329,9 +329,8 @@ corresponding to the union of all properties for union elements. For brevity, only a few examples are shown here: ```go -type CallToolParams struct { - Arguments map[string]json.RawMessage `json:"arguments,omitempty"` - Name string `json:"name"` +type ReadResourceParams struct { + URI string `json:"uri"` } type CallToolResult struct { @@ -428,10 +427,7 @@ transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) if err != nil { ... } // Call a tool on the server. -content, err := session.CallTool(ctx, &CallToolParams{ - Name: "greet", - Arguments: map[string]any{"name": "you"} , -}) +content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil}) ... return session.Close() ``` @@ -479,17 +475,20 @@ documentation. 
### Spec Methods -As we saw above, the `ClientSession` method for the specification's -`CallTool` RPC takes a context and a params pointer as arguments, and returns a -result pointer and error: +In our SDK, RPC methods that are defined in the specification take a context and +a params pointer as arguments, and return a result pointer and error: ```go -func (*ClientSession) CallTool(context.Context, *CallToolParams) (*CallToolResult, error) +func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsResult, error) ``` -Our SDK has a method for every RPC in the spec, and their signatures all share -this form. To avoid boilerplate, we don't repeat this signature for RPCs -defined in the spec; readers may assume it when we mention a "spec method." +Our SDK has a method for every RPC in the spec, and except for `CallTool`, +their signatures all share this form. To avoid boilerplate, we don't repeat this +signature for RPCs defined in the spec; readers may assume it when we mention a +"spec method." + +`CallTool` is the only exception: for convenience, it takes the tool name and +arguments, with an options truct for additional request fields. Why do we use params instead of the full request? JSON-RPC requests consist of a method name and a set of parameters, and the method is already encoded in the Go method name. 
@@ -564,7 +563,7 @@ can cancel an operation by cancelling the associated context: ```go ctx, cancel := context.WithCancel(ctx) -go session.CallTool(ctx, "slow", map[string]any{}) +go session.CallTool(ctx, "slow", map[string]any{}, nil) cancel() ``` @@ -687,7 +686,7 @@ type Tool struct { Name string `json:"name"` } -type ToolHandler func(context.Context, *ServerSession, map[string]json.RawMessage) (*CallToolResult, error) +type ToolHandler func(context.Context, *ServerSession, *CallToolParams) (*CallToolResult, error) type ServerTool struct { Tool Tool diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index f62f0ae3c53..8703d9fc8c9 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -160,7 +160,7 @@ func TestEndToEnd(t *testing.T) { t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) } - gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}) + gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { t.Fatal(err) } @@ -171,7 +171,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } - gotFail, err := c.CallTool(ctx, "fail", map[string]any{}) + gotFail, err := c.CallTool(ctx, "fail", map[string]any{}, nil) // Counter-intuitively, when a tool fails, we don't expect an RPC error for // call tool: instead, the failure is embedded in the result. 
if err != nil { @@ -317,12 +317,12 @@ func TestServerClosing(t *testing.T) { } wg.Done() }() - if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}); err != nil { + if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil); err != nil { t.Fatalf("after connecting: %v", err) } cc.Close() wg.Wait() - if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}); !errors.Is(err, ErrConnectionClosed) { + if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil); !errors.Is(err, ErrConnectionClosed) { t.Errorf("after disconnection, got error %v, want EOF", err) } } @@ -384,7 +384,7 @@ func TestCancellation(t *testing.T) { defer sc.Close() ctx, cancel := context.WithCancel(context.Background()) - go sc.CallTool(ctx, "slow", map[string]any{}) + go sc.CallTool(ctx, "slow", map[string]any{}, nil) <-start cancel() select { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index a0e1b210c43..5e3f9e4c748 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -204,7 +204,7 @@ func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *Cal if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) } - return tool.Handler(ctx, cc, params.Arguments) + return tool.Handler(ctx, cc, params) } func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *ListResourcesParams) (*ListResourcesResult, error) { diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index e596fe69c34..ab39089f9af 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -39,7 +39,7 @@ func ExampleServer() { log.Fatal(err) } - res, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}) + res, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go 
index 397a1f8d5c1..b220b0ccfb0 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -40,7 +40,7 @@ func ExampleSSEHandler() { } defer client.Close() - res, err := client.CallTool(ctx, "add", map[string]any{"x": 1, "y": 2}) + res, err := client.CallTool(ctx, "add", map[string]any{"x": 1, "y": 2}, nil) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 52c50fb42de..edf1aadf991 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -43,7 +43,7 @@ func TestSSEServer(t *testing.T) { t.Fatal(err) } cc := <-conns - gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}) + gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { t.Fatal(err) } diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 4bfce459ebb..49fefb4e0de 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -13,7 +13,7 @@ import ( ) // A ToolHandler handles a call to tools/call. -type ToolHandler func(context.Context, *ServerConnection, map[string]json.RawMessage) (*CallToolResult, error) +type ToolHandler func(context.Context, *ServerConnection, *CallToolParams) (*CallToolResult, error) // A Tool is a tool definition that is bound to a tool handler. type ServerTool struct { @@ -39,10 +39,10 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * if err != nil { panic(err) } - wrapped := func(ctx context.Context, cc *ServerConnection, args map[string]json.RawMessage) (*CallToolResult, error) { + wrapped := func(ctx context.Context, cc *ServerConnection, params *CallToolParams) (*CallToolResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. 
- rawArgs, err := json.Marshal(args) + rawArgs, err := json.Marshal(params.Arguments) if err != nil { return nil, err } From cd1dd2807a3f70f62e74086fd820ff2de3f1258e Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 14 May 2025 07:24:51 -0400 Subject: [PATCH 062/196] internal/mcp: use RawMessage for tool args Change the type of CallToolParams.Arguments to json.RawMessage. Change-Id: Ief16ef261b9f41566a1a16efc60a9382a8d414eb Reviewed-on: https://go-review.googlesource.com/c/tools/+/672615 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 14 +++++--------- internal/mcp/generate.go | 11 +++++++++-- internal/mcp/protocol.go | 4 ++-- internal/mcp/tool.go | 8 +------- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 6dc0ab3a1ff..9467faba6a0 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -181,18 +181,14 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any, err = fmt.Errorf("calling tool %q: %w", name, err) } }() - argsJSON := make(map[string]json.RawMessage) - for name, arg := range args { - argJSON, err := json.Marshal(arg) - if err != nil { - return nil, fmt.Errorf("marshaling argument %s: %v", name, err) - } - argsJSON[name] = argJSON - } + data, err := json.Marshal(args) + if err != nil { + return nil, fmt.Errorf("marshaling arguments: %w", err) + } params := &CallToolParams{ Name: name, - Arguments: argsJSON, + Arguments: json.RawMessage(data), } return standardCall[CallToolResult](ctx, c.conn, "tools/call", params) } diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 342f1cc1122..bfeca7705b0 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -53,8 +53,15 @@ type config map[string]*typeConfig var declarations = config{ "Annotations": {}, "CallToolRequest": { - Name: "-", - Fields: config{"Params": {Name: "CallToolParams"}}, + Name: "-", + Fields: config{ + 
"Params": { + Name: "CallToolParams", + Fields: config{ + "Arguments": {Substitute: "json.RawMessage"}, + }, + }, + }, }, "CallToolResult": {}, "CancelledNotification": { diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 95444a60edc..dcccceff509 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -29,8 +29,8 @@ type Annotations struct { } type CallToolParams struct { - Arguments map[string]json.RawMessage `json:"arguments,omitempty"` - Name string `json:"name"` + Arguments json.RawMessage `json:"arguments,omitempty"` + Name string `json:"name"` } // The server's response to a tool call. diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 49fefb4e0de..9238f80c996 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -40,14 +40,8 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * panic(err) } wrapped := func(ctx context.Context, cc *ServerConnection, params *CallToolParams) (*CallToolResult, error) { - // For simplicity, just marshal and unmarshal the arguments. - // This could be avoided in the future. 
- rawArgs, err := json.Marshal(params.Arguments) - if err != nil { - return nil, err - } var v TReq - if err := unmarshalSchema(rawArgs, schema, &v); err != nil { + if err := unmarshalSchema(params.Arguments, schema, &v); err != nil { return nil, err } content, err := handler(ctx, cc, v) From 283948d3bbbb4090261fcc948da88d2fed756264 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 14 May 2025 14:13:30 +0000 Subject: [PATCH 063/196] internal/mcp: document iterators Change-Id: I5f7a6bad296d23c1bc075eb767a243c24a42668c Reviewed-on: https://go-review.googlesource.com/c/tools/+/672675 Reviewed-by: Jonathan Amsterdam Reviewed-by: Robert Findley Auto-Submit: Sam Thanawalla TryBot-Bypass: Sam Thanawalla Commit-Queue: Sam Thanawalla --- internal/mcp/design/design.md | 38 ++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index d82d3a3fec6..db7aa3c5dc1 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -496,6 +496,22 @@ Technically, the MCP spec could add a field to a request while preserving backwa compatibility, which would break the Go SDK's compatibility. But in the unlikely event that were to happen, we would add that field to the Params struct. +#### Iterator Methods + +For convenience, iterator methods handle pagination for the `List` spec methods +automatically, traversing all pages. If Params are supplied, iteration begins +from the provided cursor (if present). 
+ +```go +func (*ClientSession) Tools(context.Context, *ListToolsParams) iter.Seq2[Tool, error] + +func (*ClientSession) Prompts(context.Context, *ListPromptsParams) iter.Seq2[Prompt, error] + +func (*ClientSession) Resources(context.Context, *ListResourceParams) iter.Seq2[Resource, error] + +func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesParams) iter.Seq2[ResourceTemplate, error] +``` + ### Middleware We provide a mechanism to add MCP-level middleware, which runs after the @@ -793,6 +809,9 @@ Schemas are validated on the server before the tool handler is called. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. +Clients can call the spec method `ListTools` or an iterator method `Tools` +to list the available tools. + **Differences from mcp-go**: using variadic options to configure tools was signficantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. @@ -847,7 +866,8 @@ server.AddPrompts( server.RemovePrompts("code_review") ``` -Clients can call the spec method `ListPrompts` to list the available prompts and the spec method `GetPrompt` to get one. +Clients can call the spec method `ListPrompts` or an iterator method `Prompts` +to list the available prompts and the spec method `GetPrompt` to get one. **Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. We provide @@ -907,6 +927,8 @@ Servers support all of the resource-related spec methods: - `ReadResource` to get the contents of a resource. - `Subscribe` and `Unsubscribe` to manage subscriptions on resources. +We also provide iterator methods `Resources` and `ResourceTemplates`. 
+ `ReadResource` checks the incoming URI against the server's list of resources and resource templates to make sure it matches one of them, then returns the result of calling the associated reader function. @@ -1042,16 +1064,4 @@ more pages exist. In addition to the `List` methods, the SDK provides an iterator method for each list operation. This simplifies pagination for cients by automatically handling -the underlying pagination logic. - -For example, we if we have a List method like this: - -```go -func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsResult, error) -``` - -We will also provide an iterator method like this: - -```go -func (*ClientSession) Tools(context.Context, *ListToolsParams) iter.Seq2[Tool, error] -``` +the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. From 274b895dc1c5344c91931b1e51b2ae495ad26eef Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 12:23:19 -0400 Subject: [PATCH 064/196] internal/mcp: resource tweaks Get the name and signature of the handler right. Change-Id: I743aaaee45cc07a741ab8ad16e90e1b1d5519b64 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672336 LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley --- internal/mcp/design/design.md | 67 ++++++++++++++++++++--------------- internal/mcp/mcp_test.go | 6 ++-- internal/mcp/server.go | 12 +++---- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index db7aa3c5dc1..bea622955f5 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -875,52 +875,61 @@ handler to a Go function using reflection to derive its arguments. 
We provide ### Resources and resource templates -To add a resource or resource template to a server, users call the `AddResource` and -`AddResourceTemplate` methods, passing the resource or template and a function for reading it: +In our design, each resource and resource template is associated with a function that reads it, +with this signature: +```go +type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) +``` +The arguments include the `ServerSession` so the handler can observe the client's roots. +The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` +or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the +resource. +The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: ```go -type ReadResourceHandler func(context.Context, *ServerSession, *Resource, *ReadResourceParams) (*ReadResourceResult, error) +type ServerResource struct { + Resource Resource + Handler ResourceHandler +} -func (*Server) AddResource(*Resource, ReadResourceHandler) -func (*Server) AddResourceTemplate(*ResourceTemplate, ReadResourceHandler) +type ServerResourceTemplate struct { + Template ResourceTemplate + Handler ResourceHandler +} ``` -The `Resource` is passed to the reader function even though it is redundant (the function could have closed over it) -so a single handler can support multiple resources. -If the incoming resource matches a template, a `Resource` argument is constructed -from the fields in the `ResourceTemplate`. -The `ServerSession` argument is there so the reader can observe the client's roots. 
+To add a resource or resource template to a server, users call the `AddResources` and +`AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s: + +```go +func (*Server) AddResources(...*ServerResource) +func (*Server) AddResourceTemplates(...*ServerResourceTemplate) + +func (s *Server) RemoveResources(uris ...string) +func (s *Server) RemoveResourceTemplates(uriTemplates ...string) +``` -To read files from the local filesystem, we recommend using `FileReadResourceHandler` to construct a handler: +The `ReadResource` method finds a resource or resource template matching the argument URI and calls +its assocated handler. +If the argument URI matches a template, the `Resource` argument to the handler is constructed +from the fields in the `ResourceTemplate`. +To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: ```go -// FileReadResourceHandler returns a ReadResourceHandler that reads paths using dir as a root directory. +// FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. // It protects against path traversal attacks. // It will not read any file that is not in the root set of the client requesting the resource. -func (*Server) FileReadResourceHandler(dir string) ReadResourceHandler +func (*Server) FileResourceHandler(dir string) ResourceHandler ``` - -It guards against [path traversal attacks](https://go.dev/blog/osroot) -and observes the client's roots. Here is an example: ```go // Safely read "/public/puppies.txt". -s.AddResource( - &mcp.Resource{URI: "file:///puppies.txt"}, - s.FileReadResourceHandler("/public")) -``` - -There are also server methods to remove resources and resource templates. 
- -```go -func (*Server) RemoveResources(uris ...string) -func (*Server) RemoveResourceTemplates(names ...string) +s.AddResources(&mcp.ServerResource{ + Resource: mcp.Resource{URI: "file:///puppies.txt"}, + Handler: s.FileReadResourceHandler("/public")}) ``` -Resource templates don't have unique identifiers, so removing a name will remove all -resource templates with that name. - Servers support all of the resource-related spec methods: - `ListResources` and `ListResourceTemplates` for listings. diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 8703d9fc8c9..50b64881087 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -198,15 +198,15 @@ func TestEndToEnd(t *testing.T) { URI: "file:///nonexistent.txt", } - readHandler := func(_ context.Context, r *Resource, _ *ReadResourceParams) (*ReadResourceResult, error) { - if r.URI == "file:///file1.txt" { + readHandler := func(_ context.Context, _ *ServerConnection, p *ReadResourceParams) (*ReadResourceResult, error) { + if p.URI == "file:///file1.txt" { return &ReadResourceResult{ Contents: &ResourceContents{ Text: "file contents", }, }, nil } - return nil, ResourceNotFoundError(r.URI) + return nil, ResourceNotFoundError(p.URI) } s.AddResources( &ServerResource{resource1, readHandler}, diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 5e3f9e4c748..3dd4f102485 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -119,18 +119,18 @@ func ResourceNotFoundError(uri string) error { // The immediate problem is that jsonprc2 defines -32002 as "server closing". const codeResourceNotFound = -31002 -// A ReadResourceHandler is a function that reads a resource. +// A ResourceHandler is a function that reads a resource. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. 
-type ReadResourceHandler func(context.Context, *Resource, *ReadResourceParams) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ServerConnection, *ReadResourceParams) (*ReadResourceResult, error) // A ServerResource associates a Resource with its handler. type ServerResource struct { Resource *Resource - Handler ReadResourceHandler + Handler ResourceHandler } // AddResource adds the given resource to the server and associates it with -// a [ReadResourceHandler], which will be called when the client calls [ClientSession.ReadResource]. +// a [ResourceHandler], which will be called when the client calls [ClientSession.ReadResource]. // If a resource with the same URI already exists, this one replaces it. // AddResource panics if a resource URI is invalid or not absolute (has an empty scheme). func (s *Server) AddResources(resources ...*ServerResource) { @@ -217,7 +217,7 @@ func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *L return res, nil } -func (s *Server) readResource(ctx context.Context, _ *ServerConnection, params *ReadResourceParams) (*ReadResourceResult, error) { +func (s *Server) readResource(ctx context.Context, ss *ServerConnection, params *ReadResourceParams) (*ReadResourceResult, error) { uri := params.URI // Look up the resource URI in the list we have. // This is a security check as well as an information lookup. @@ -229,7 +229,7 @@ func (s *Server) readResource(ctx context.Context, _ *ServerConnection, params * // Treat an unregistered resource the same as a registered one that couldn't be found. 
return nil, ResourceNotFoundError(uri) } - res, err := resource.Handler(ctx, resource.Resource, params) + res, err := resource.Handler(ctx, ss, params) if err != nil { return nil, err } From c905b91fc973580c9eff98577847c177699be39f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 17:30:55 -0400 Subject: [PATCH 065/196] internal/mcp: design.md: clarify backward compatibility guarantee Change-Id: I2cb8be82a91dea7755ea5514ee1fbe1e937135c9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672376 Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla TryBot-Bypass: Jonathan Amsterdam Commit-Queue: Jonathan Amsterdam --- internal/mcp/design/design.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index bea622955f5..558161dfb64 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -34,6 +34,7 @@ compatible with mcp-go, translating between them should be straightforward in most cases. (Later, we will provide a detailed translation guide.) + # Requirements These may be obvious, but it's worthwhile to define goals for an official MCP @@ -483,12 +484,19 @@ func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsRe ``` Our SDK has a method for every RPC in the spec, and except for `CallTool`, -their signatures all share this form. To avoid boilerplate, we don't repeat this +their signatures all share this form. +We do this, rather than providing more convenient shortcut signatures, +to maintain backward compatibility if the spec makes backward-compatible changes +such as adding a new property to the request parameters. +To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." `CallTool` is the only exception: for convenience, it takes the tool name and arguments, with an options truct for additional request fields. 
+Our SDK has a method for every RPC in the spec, and their signatures all share +this form. To avoid boilerplate, we don't repeat this signature for RPCs +defined in the spec; readers may assume it when we mention a "spec method." Why do we use params instead of the full request? JSON-RPC requests consist of a method name and a set of parameters, and the method is already encoded in the Go method name. @@ -496,6 +504,14 @@ Technically, the MCP spec could add a field to a request while preserving backwa compatibility, which would break the Go SDK's compatibility. But in the unlikely event that were to happen, we would add that field to the Params struct. +We believe that any change to the spec that would require callers to pass a new a +parameter is not backward compatible. Therefore, it will always work to pass +`nil` for any `XXXParams` argument that isn't currently necessary. For example, it is okay to call `Ping` like so: + +```go +err := session.Ping(ctx, nil)` +``` + #### Iterator Methods For convenience, iterator methods handle pagination for the `List` spec methods @@ -512,6 +528,7 @@ func (*ClientSession) Resources(context.Context, *ListResourceParams) iter.Seq2[ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesParams) iter.Seq2[ResourceTemplate, error] ``` + ### Middleware We provide a mechanism to add MCP-level middleware, which runs after the From 9a093bc3c3808c9c641190e23040065b1a33785d Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 2025 14:00:02 +0000 Subject: [PATCH 066/196] internal/mcp: rename ServerConnection->ServerSession Bring the Server implementation up to date with the spec by renaming ServerConnection to ServerSession. Also update variable names, which were in various states of disarray from the multiple renamings of this type. 
Change-Id: I4aec883db6d0500fb91e6daa51b702614877cf8a Reviewed-on: https://go-review.googlesource.com/c/tools/+/672655 Auto-Submit: Robert Findley TryBot-Bypass: Robert Findley Commit-Queue: Robert Findley Reviewed-by: Jonathan Amsterdam --- gopls/internal/mcp/mcp.go | 4 +- internal/mcp/README.md | 2 +- internal/mcp/examples/hello/main.go | 4 +- internal/mcp/examples/sse/main.go | 2 +- internal/mcp/mcp.go | 8 ++-- internal/mcp/mcp_test.go | 32 +++++++------- internal/mcp/prompt.go | 6 +-- internal/mcp/prompt_test.go | 2 +- internal/mcp/server.go | 65 ++++++++++++++--------------- internal/mcp/server_example_test.go | 6 +-- internal/mcp/sse.go | 8 ++-- internal/mcp/sse_example_test.go | 2 +- internal/mcp/sse_test.go | 4 +- internal/mcp/tool.go | 6 +-- internal/mcp/tool_test.go | 2 +- 15 files changed, 76 insertions(+), 77 deletions(-) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index cd8236e53c7..8d1b115ad34 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -140,8 +140,8 @@ type HelloParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]*mcp.Content, error) { - return func(ctx context.Context, cc *mcp.ServerConnection, request *HelloParams) ([]*mcp.Content, error) { +func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { + return func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent("Hi " + request.Name + ", this is lsp session " + session.ID()), }, nil diff --git a/internal/mcp/README.md b/internal/mcp/README.md index 34e53d86230..1f616746bb1 100644 --- a/internal/mcp/README.md +++ b/internal/mcp/README.md @@ -59,7 +59,7 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc 
*mcp.ServerConnection, params *HiParams) ([]mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]mcp.Content, error) { return []mcp.Content{ mcp.TextContent{Text: "Hi " + params.Name}, }, nil diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index b71f86d242a..8fbd74c66ad 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -20,13 +20,13 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) ([]*mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil } -func PromptHi(ctx context.Context, cc *mcp.ServerConnection, params *HiParams) (*mcp.GetPromptResult, error) { +func PromptHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []*mcp.PromptMessage{ diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index ba793936621..b5a1cec1aac 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -19,7 +19,7 @@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]*mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 201f42092f2..d6ee36915b5 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -10,9 +10,9 @@ // To get started, create either a [Client] or [Server], and connect it to a // peer using a [Transport]. 
The diagram below illustrates how this works: // -// Client Server -// ⇅ (jsonrpc2) ⇅ -// Client Transport ⇄ Server Transport ⇄ ServerConnection +// Client Server +// ⇅ (jsonrpc2) ⇅ +// Client Transport ⇄ Server Transport ⇄ ServerSession // // A [Client] is an MCP client, which can be configured with various client // capabilities. Clients may be connected to a [Server] instance @@ -21,7 +21,7 @@ // Similarly, a [Server] is an MCP server, which can be configured with various // server capabilities. Servers may be connected to one or more [Client] // instances using the [Server.Connect] method, which creates a -// [ServerConnection]. +// [ServerSession]. // // A [Transport] connects a bidirectional [Stream] of jsonrpc2 messages. In // practice, transports in the MCP spec are are either client transports or diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 50b64881087..253254adcce 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -24,7 +24,7 @@ type hiParams struct { Name string } -func sayHi(ctx context.Context, cc *ServerConnection, v hiParams) ([]*Content, error) { +func sayHi(ctx context.Context, cc *ServerSession, v hiParams) ([]*Content, error) { if err := cc.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } @@ -43,13 +43,13 @@ func TestEndToEnd(t *testing.T) { // The 'fail' tool returns this error. 
failure := errors.New("mcp failure") s.AddTools( - NewTool("fail", "just fail", func(context.Context, *ServerConnection, struct{}) ([]*Content, error) { + NewTool("fail", "just fail", func(context.Context, *ServerSession, struct{}) ([]*Content, error) { return nil, failure }), ) s.AddPrompts( - NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerConnection, params struct{ Code string }) (*GetPromptResult, error) { + NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerSession, params struct{ Code string }) (*GetPromptResult, error) { return &GetPromptResult{ Description: "Code review prompt", Messages: []*PromptMessage{ @@ -57,17 +57,17 @@ func TestEndToEnd(t *testing.T) { }, }, nil }), - NewPrompt("fail", "", func(_ context.Context, _ *ServerConnection, params struct{}) (*GetPromptResult, error) { + NewPrompt("fail", "", func(_ context.Context, _ *ServerSession, params struct{}) (*GetPromptResult, error) { return nil, failure }), ) // Connect the server. 
- sc, err := s.Connect(ctx, st, nil) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - if got := slices.Collect(s.Clients()); len(got) != 1 { + if got := slices.Collect(s.Sessions()); len(got) != 1 { t.Errorf("after connection, Clients() has length %d, want 1", len(got)) } @@ -75,7 +75,7 @@ func TestEndToEnd(t *testing.T) { var clientWG sync.WaitGroup clientWG.Add(1) go func() { - if err := sc.Wait(); err != nil { + if err := ss.Wait(); err != nil { t.Errorf("server failed: %v", err) } clientWG.Done() @@ -198,7 +198,7 @@ func TestEndToEnd(t *testing.T) { URI: "file:///nonexistent.txt", } - readHandler := func(_ context.Context, _ *ServerConnection, p *ReadResourceParams) (*ReadResourceResult, error) { + readHandler := func(_ context.Context, _ *ServerSession, p *ReadResourceParams) (*ReadResourceResult, error) { if p.URI == "file:///file1.txt" { return &ReadResourceResult{ Contents: &ResourceContents{ @@ -249,9 +249,9 @@ func TestEndToEnd(t *testing.T) { } }) t.Run("roots", func(t *testing.T) { - // Take the server's first ServerConnection. - var sc *ServerConnection - for sc = range s.Clients() { + // Take the server's first ServerSession. + var sc *ServerSession + for sc = range s.Sessions() { break } @@ -272,7 +272,7 @@ func TestEndToEnd(t *testing.T) { // After disconnecting, neither client nor server should have any // connections. - for range s.Clients() { + for range s.Sessions() { t.Errorf("unexpected client after disconnection") } } @@ -282,7 +282,7 @@ func TestEndToEnd(t *testing.T) { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. 
-func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerConnection, *Client) { +func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Client) { t.Helper() ctx := context.Background() @@ -292,7 +292,7 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerConnection, *Cl // The 'greet' tool says hi. s.AddTools(tools...) - cc, err := s.Connect(ctx, st, nil) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -301,7 +301,7 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerConnection, *Cl if err := c.Start(ctx); err != nil { t.Fatal(err) } - return cc, c + return ss, c } func TestServerClosing(t *testing.T) { @@ -370,7 +370,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, cc *ServerConnection, v struct{}) ([]*Content, error) { + slowRequest := func(ctx context.Context, cc *ServerSession, v struct{}) ([]*Content, error) { start <- struct{}{} select { case <-ctx.Done(): diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index c813cee5413..d6a2b117269 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -16,7 +16,7 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerConnection, map[string]string) (*GetPromptResult, error) +type PromptHandler func(context.Context, *ServerSession, map[string]string) (*GetPromptResult, error) // A Prompt is a prompt definition bound to a prompt handler. type ServerPrompt struct { @@ -32,7 +32,7 @@ type ServerPrompt struct { // of type string or *string. The argument names for the resulting prompt // definition correspond to the JSON names of the request fields, and any // fields that are not marked "omitempty" are considered required. 
-func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt { +func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) @@ -60,7 +60,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, Required: required[name], }) } - prompt.Handler = func(ctx context.Context, cc *ServerConnection, args map[string]string) (*GetPromptResult, error) { + prompt.Handler = func(ctx context.Context, cc *ServerSession, args map[string]string) (*GetPromptResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. rawArgs, err := json.Marshal(args) diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 88220fa577b..fa2e3bc0a71 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -13,7 +13,7 @@ import ( ) // testPromptHandler is used for type inference in TestNewPrompt. -func testPromptHandler[T any](context.Context, *mcp.ServerConnection, T) (*mcp.GetPromptResult, error) { +func testPromptHandler[T any](context.Context, *mcp.ServerSession, T) (*mcp.GetPromptResult, error) { panic("not implemented") } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 3dd4f102485..0b741cdfd56 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -30,7 +30,7 @@ type Server struct { prompts *featureSet[*ServerPrompt] tools *featureSet[*ServerTool] resources *featureSet[*ServerResource] - conns []*ServerConnection + conns []*ServerSession } // ServerOptions is used to configure behavior of the server. @@ -121,7 +121,7 @@ const codeResourceNotFound = -31002 // A ResourceHandler is a function that reads a resource. 
// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerConnection, *ReadResourceParams) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) // A ServerResource associates a Resource with its handler. type ServerResource struct { @@ -157,16 +157,15 @@ func (s *Server) RemoveResources(uris ...string) { s.resources.remove(uris...) } -// Clients returns an iterator that yields the current set of client -// connections. -func (s *Server) Clients() iter.Seq[*ServerConnection] { +// Sessions returns an iterator that yields the current set of server sessions. +func (s *Server) Sessions() iter.Seq[*ServerSession] { s.mu.Lock() clients := slices.Clone(s.conns) s.mu.Unlock() return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *ListPromptsParams) (*ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() res := new(ListPromptsResult) @@ -176,7 +175,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerConnection, params *Lis return res, nil } -func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *GetPromptParams) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(params.Name) s.mu.Unlock() @@ -187,7 +186,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerConnection, params *Ge return prompt.Handler(ctx, cc, params.Arguments) } -func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *ListToolsParams) (*ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) 
(*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() res := new(ListToolsResult) @@ -197,7 +196,7 @@ func (s *Server) listTools(_ context.Context, _ *ServerConnection, params *ListT return res, nil } -func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *CallToolParams) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParams) (*CallToolResult, error) { s.mu.Lock() tool, ok := s.tools.get(params.Name) s.mu.Unlock() @@ -207,7 +206,7 @@ func (s *Server) callTool(ctx context.Context, cc *ServerConnection, params *Cal return tool.Handler(ctx, cc, params) } -func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *ListResourcesParams) (*ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() res := new(ListResourcesResult) @@ -217,7 +216,7 @@ func (s *Server) listResources(_ context.Context, _ *ServerConnection, params *L return res, nil } -func (s *Server) readResource(ctx context.Context, ss *ServerConnection, params *ReadResourceParams) (*ReadResourceResult, error) { +func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { uri := params.URI // Look up the resource URI in the list we have. // This is a security check as well as an information lookup. @@ -250,29 +249,29 @@ func (s *Server) readResource(ctx context.Context, ss *ServerConnection, params // // Run blocks until the client terminates the connection. 
func (s *Server) Run(ctx context.Context, t Transport, opts *ConnectionOptions) error { - cc, err := s.Connect(ctx, t, opts) + ss, err := s.Connect(ctx, t, opts) if err != nil { return err } - return cc.Wait() + return ss.Wait() } -// bind implements the binder[*ServerConnection] interface, so that Servers can +// bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(conn *jsonrpc2.Connection) *ServerConnection { - cc := &ServerConnection{conn: conn, server: s} +func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { + cc := &ServerSession{conn: conn, server: s} s.mu.Lock() s.conns = append(s.conns, cc) s.mu.Unlock() return cc } -// disconnect implements the binder[*ServerConnection] interface, so that +// disconnect implements the binder[*ServerSession] interface, so that // Servers can be connected using [connect]. -func (s *Server) disconnect(cc *ServerConnection) { +func (s *Server) disconnect(cc *ServerSession) { s.mu.Lock() defer s.mu.Unlock() - s.conns = slices.DeleteFunc(s.conns, func(cc2 *ServerConnection) bool { + s.conns = slices.DeleteFunc(s.conns, func(cc2 *ServerSession) bool { return cc2 == cc }) } @@ -283,17 +282,17 @@ func (s *Server) disconnect(cc *ServerConnection) { // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (*ServerConnection, error) { +func (s *Server) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (*ServerSession, error) { return connect(ctx, t, opts, s) } -// A ServerConnection is a connection from a single MCP client. Its methods can -// be used to send requests or notifications to the client. Create a connection -// by calling [Server.Connect]. +// A ServerSession is a logical connection from a single MCP client. 
Its +// methods can be used to send requests or notifications to the client. Create +// a session by calling [Server.Connect]. // -// Call [ServerConnection.Close] to close the connection, or await client -// termination with [ServerConnection.Wait]. -type ServerConnection struct { +// Call [ServerSession.Close] to close the connection, or await client +// termination with [ServerSession.Wait]. +type ServerSession struct { server *Server conn *jsonrpc2.Connection @@ -303,15 +302,15 @@ type ServerConnection struct { } // Ping makes an MCP "ping" request to the client. -func (cc *ServerConnection) Ping(ctx context.Context, _ *PingParams) error { +func (cc *ServerSession) Ping(ctx context.Context, _ *PingParams) error { return call(ctx, cc.conn, "ping", nil, nil) } -func (cc *ServerConnection) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { +func (cc *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { return standardCall[ListRootsResult](ctx, cc.conn, "roots/list", params) } -func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { +func (cc *ServerSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { cc.mu.Lock() initialized := cc.initialized cc.mu.Unlock() @@ -362,7 +361,7 @@ func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) ( return nil, jsonrpc2.ErrNotHandled } -func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, params *initializeParams) (*initializeResult, error) { +func (cc *ServerSession) initialize(ctx context.Context, _ *ServerSession, params *initializeParams) (*initializeResult, error) { cc.mu.Lock() cc.initializeParams = params cc.mu.Unlock() @@ -400,12 +399,12 @@ func (cc *ServerConnection) initialize(ctx context.Context, _ *ServerConnection, // Close performs a graceful shutdown of the connection, preventing new // requests from being handled, and waiting 
for ongoing requests to return. // Close then terminates the connection. -func (cc *ServerConnection) Close() error { +func (cc *ServerSession) Close() error { return cc.conn.Close() } // Wait waits for the connection to be closed by the client. -func (cc *ServerConnection) Wait() error { +func (cc *ServerSession) Wait() error { return cc.conn.Wait() } @@ -413,7 +412,7 @@ func (cc *ServerConnection) Wait() error { // // Importantly, it returns nil if the handler returned an error, which is a // requirement of the jsonrpc2 package. -func dispatch[TParams, TResult any](ctx context.Context, conn *ServerConnection, req *jsonrpc2.Request, f func(context.Context, *ServerConnection, TParams) (TResult, error)) (any, error) { +func dispatch[TParams, TResult any](ctx context.Context, conn *ServerSession, req *jsonrpc2.Request, f func(context.Context, *ServerSession, TParams) (TResult, error)) (any, error) { var params TParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index ab39089f9af..5dede1f78f0 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -16,7 +16,7 @@ type SayHiParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func SayHi(ctx context.Context, cc *mcp.ServerConnection, params *SayHiParams) ([]*mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil @@ -29,7 +29,7 @@ func ExampleServer() { server := mcp.NewServer("greeter", "v0.0.1", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) - clientConnection, err := server.Connect(ctx, serverTransport, nil) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } @@ -46,7 +46,7 @@ func ExampleServer() { fmt.Println(res.Content[0].Text) client.Close() - 
clientConnection.Wait() + serverSession.Wait() // Output: Hi user } diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index c2acdd5c201..f1f657f94bb 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -64,7 +64,7 @@ func writeEvent(w io.Writer, evt event) (int, error) { // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { getServer func(request *http.Request) *Server - onConnection func(*ServerConnection) // for testing; must not block + onConnection func(*ServerSession) // for testing; must not block mu sync.Mutex sessions map[string]*SSEServerTransport @@ -229,15 +229,15 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // TODO(hxjiang): getServer returns nil will panic. server := h.getServer(req) - cc, err := server.Connect(req.Context(), transport, nil) + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return } if h.onConnection != nil { - h.onConnection(cc) + h.onConnection(ss) } - defer cc.Close() // close the transport when the GET exits + defer ss.Close() // close the transport when the GET exits select { case <-req.Context().Done(): diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index b220b0ccfb0..391746f58fd 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -18,7 +18,7 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, cc *mcp.ServerConnection, params *AddParams) ([]*mcp.Content, error) { +func Add(ctx context.Context, cc *mcp.ServerSession, params *AddParams) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent(fmt.Sprintf("%d", params.X+params.Y)), }, nil diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index edf1aadf991..661cb5436f8 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -23,8 +23,8 @@ func TestSSEServer(t *testing.T) { 
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) - conns := make(chan *ServerConnection, 1) - sseHandler.onConnection = func(cc *ServerConnection) { + conns := make(chan *ServerSession, 1) + sseHandler.onConnection = func(cc *ServerSession) { select { case conns <- cc: default: diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 9238f80c996..f7b04660f08 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -13,7 +13,7 @@ import ( ) // A ToolHandler handles a call to tools/call. -type ToolHandler func(context.Context, *ServerConnection, *CallToolParams) (*CallToolResult, error) +type ToolHandler func(context.Context, *ServerSession, *CallToolParams) (*CallToolResult, error) // A Tool is a tool definition that is bound to a tool handler. type ServerTool struct { @@ -34,12 +34,12 @@ type ServerTool struct { // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func NewTool[TReq any](name, description string, handler func(context.Context, *ServerConnection, TReq) ([]*Content, error), opts ...ToolOption) *ServerTool { +func NewTool[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq) ([]*Content, error), opts ...ToolOption) *ServerTool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) } - wrapped := func(ctx context.Context, cc *ServerConnection, params *CallToolParams) (*CallToolResult, error) { + wrapped := func(ctx context.Context, cc *ServerSession, params *CallToolParams) (*CallToolResult, error) { var v TReq if err := unmarshalSchema(params.Arguments, schema, &v); err != nil { return nil, err diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index a4202a221fe..88694af3e75 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -15,7 +15,7 @@ import ( ) // testToolHandler is used for type inference in TestNewTool. 
-func testToolHandler[T any](context.Context, *mcp.ServerConnection, T) ([]*mcp.Content, error) { +func testToolHandler[T any](context.Context, *mcp.ServerSession, T) ([]*mcp.Content, error) { panic("not implemented") } From cb65fbb91b670e110d685c28dadf06a3c4c8f0a9 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 2025 14:59:59 +0000 Subject: [PATCH 067/196] internal/mcp: rename LocalTransport->InMemoryTransport, and document Update the design to mention our in-memory transport, since it is generally useful, and update the implementation to use distinct types for InMemoryTransport and StdIOTransport (the fact that they are both implemented as an io.ReadWriteCloser is an implementation detail). Change-Id: I8228ccf2092a56a730c113c4db105fb492cce112 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672715 Reviewed-by: Sam Thanawalla Commit-Queue: Robert Findley Reviewed-by: Jonathan Amsterdam TryBot-Bypass: Robert Findley --- internal/mcp/design/design.md | 8 ++++++++ internal/mcp/mcp_test.go | 6 +++--- internal/mcp/server_example_test.go | 2 +- internal/mcp/transport.go | 30 ++++++++++++++++++++--------- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 558161dfb64..12cab8b9b62 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -283,6 +283,14 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport { func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ``` +Finally, we also provide an in-memory transport, for scenarios such as testing, +where the MCP client and server are in the same process. + +```go +type InMemoryTransport struct { /* ... */ } +func NewInMemoryTransport() (*InMemoryTransport, *InMemoryTransport) +``` + **Differences from mcp-go**: The Go team has a battle-tested JSON-RPC implementation that we use for gopls, our Go LSP server. 
We are using the new version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 253254adcce..9f7d0fe12bb 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -33,7 +33,7 @@ func sayHi(ctx context.Context, cc *ServerSession, v hiParams) ([]*Content, erro func TestEndToEnd(t *testing.T) { ctx := context.Background() - ct, st := NewLocalTransport() + ct, st := NewInMemoryTransport() s := NewServer("testServer", "v1.0.0", nil) @@ -286,7 +286,7 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Clien t.Helper() ctx := context.Background() - ct, st := NewLocalTransport() + ct, st := NewInMemoryTransport() s := NewServer("testServer", "v1.0.0", nil) @@ -329,7 +329,7 @@ func TestServerClosing(t *testing.T) { func TestBatching(t *testing.T) { ctx := context.Background() - ct, st := NewLocalTransport() + ct, st := NewInMemoryTransport() s := NewServer("testServer", "v1.0.0", nil) _, err := s.Connect(ctx, st, nil) diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 5dede1f78f0..ed7438184e2 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -24,7 +24,7 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]* func ExampleServer() { ctx := context.Background() - clientTransport, serverTransport := mcp.NewLocalTransport() + clientTransport, serverTransport := mcp.NewInMemoryTransport() server := mcp.NewServer("greeter", "v0.0.1", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 7c902848d60..a39d74d5b16 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -48,27 +48,39 @@ type ConnectionOptions struct { Logger io.Writer // if set, write RPC logs } -// An IOTransport is a [Transport] that communicates using newline-delimited +// A 
StdIOTransport is a [Transport] that communicates over stdin/stdout using +// newline-delimited JSON. +type StdIOTransport struct { + ioTransport +} + +// An ioTransport is a [Transport] that communicates using newline-delimited // JSON over an io.ReadWriteCloser. -type IOTransport struct { +type ioTransport struct { rwc io.ReadWriteCloser } -func (t *IOTransport) Connect(context.Context) (Stream, error) { +func (t *ioTransport) Connect(context.Context) (Stream, error) { return newIOStream(t.rwc), nil } // NewStdIOTransport constructs a transport that communicates over // stdin/stdout. -func NewStdIOTransport() *IOTransport { - return &IOTransport{rwc{os.Stdin, os.Stdout}} +func NewStdIOTransport() *StdIOTransport { + return &StdIOTransport{ioTransport{rwc{os.Stdin, os.Stdout}}} +} + +// An InMemoryTransport is a [Transport] that communicates over an in-memory +// network connection, using newline-delimited JSON. +type InMemoryTransport struct { + ioTransport } -// NewLocalTransport returns two in-memory transports that connect to -// each other, for testing purposes. -func NewLocalTransport() (*IOTransport, *IOTransport) { +// NewInMemoryTransport returns two InMemoryTransports that connect to each +// other. +func NewInMemoryTransport() (*InMemoryTransport, *InMemoryTransport) { c1, c2 := net.Pipe() - return &IOTransport{c1}, &IOTransport{c2} + return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} } // handler is an unexported version of jsonrpc2.Handler. From ad2312c50758eacfa3b6ae16c23d76dfc79fadab Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 14 May 2025 13:10:57 -0400 Subject: [PATCH 068/196] internal/mcp: design.md: link to spec change Add a link to a recent commit that illustrates a backward-compatible change. 
Change-Id: I5e277f9417cf9bbbc5146512e29a3cb8205a4e14 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672716 Reviewed-by: Sam Thanawalla TryBot-Bypass: Jonathan Amsterdam Commit-Queue: Jonathan Amsterdam --- internal/mcp/design/design.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 12cab8b9b62..51819d01704 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -495,7 +495,9 @@ Our SDK has a method for every RPC in the spec, and except for `CallTool`, their signatures all share this form. We do this, rather than providing more convenient shortcut signatures, to maintain backward compatibility if the spec makes backward-compatible changes -such as adding a new property to the request parameters. +such as adding a new property to the request parameters +(as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), +for example). To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." From 5eb0d1f4706821805f47d59de20d2d5cf2d8010e Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 14 May 2025 17:48:58 +0000 Subject: [PATCH 069/196] internal/mcp: fix typos Change-Id: I279447fd30f71bad6ea5e2b83d776ca34910c168 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672775 TryBot-Bypass: Sam Thanawalla Reviewed-by: Robert Findley Auto-Submit: Sam Thanawalla --- internal/mcp/design/design.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 51819d01704..a76177be4bf 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -219,7 +219,7 @@ create `SSEServerTransport` instances themselves, for incoming GET requests. 
// A SSEServerTransport is a logical SSE session created through a hanging GET // request. // -// When connected, it it returns the following [Stream] implementation: +// When connected, it returns the following [Stream] implementation: // - Writes are SSE 'message' events to the GET response. // - Reads are received from POSTs to the session endpoint, via // [SSEServerTransport.ServeHTTP]. @@ -503,7 +503,7 @@ signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." `CallTool` is the only exception: for convenience, it takes the tool name and -arguments, with an options truct for additional request fields. +arguments, with an options struct for additional request fields. Our SDK has a method for every RPC in the spec, and their signatures all share this form. To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." @@ -840,7 +840,7 @@ Clients can call the spec method `ListTools` or an iterator method `Tools` to list the available tools. **Differences from mcp-go**: using variadic options to configure tools was -signficantly inspired by mcp-go. However, the distinction between `ToolOption` +significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. For example, that limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), @@ -1044,13 +1044,13 @@ type ServerOptions { // The value for the "logger" field of the notification. LoggerName string // Log notifications to a single ClientSession will not be - // send more frequently than this duration. + // sent more frequently than this duration. LogInterval time.Duration } ``` ServerSessions have access to a `slog.Logger` that writes to the client. 
A call to -a log method like `Info`is translated to a `LoggingMessageNotification` as +a log method like `Info` is translated to a `LoggingMessageNotification` as follows: - The attributes and the message populate the "data" property with the @@ -1099,5 +1099,5 @@ pagination. Server responses for List methods include a `NextCursor` field if more pages exist. In addition to the `List` methods, the SDK provides an iterator method for each -list operation. This simplifies pagination for cients by automatically handling +list operation. This simplifies pagination for clients by automatically handling the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. From fdae66bf098166d8c05765cb183057cfdff19af7 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 14 May 2025 13:18:23 -0400 Subject: [PATCH 070/196] internal/mcp: design.md: polishing Various typo and style changes. Change-Id: Ifd4322dcf0bf2b2653bb38a023a51056424cca39 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672815 Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla Commit-Queue: Jonathan Amsterdam TryBot-Bypass: Jonathan Amsterdam --- internal/mcp/design/design.md | 83 +++++++++++++++-------------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index a76177be4bf..f3905f394ae 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -1,6 +1,6 @@ # Go MCP SDK design -This file discusses the design of a Go SDK for the [model context +This document discusses the design of a Go SDK for the [model context protocol](https://modelcontextprotocol.io/specification/2025-03-26). It is intended to seed a GitHub discussion about the official Go MCP SDK. @@ -18,7 +18,7 @@ writing, it is imported by over 400 packages that span over 200 modules. We admire mcp-go, and seriously considered simply adopting it as a starting point for this SDK. 
However, as we looked at doing so, we realized that a significant amount of its API would probably need to change. In some cases, -mcp-go has older APIs that predated newer variations--an obvious opportunity +mcp-go has older APIs that predated newer variations—an obvious opportunity for cleanup. In others, it took a batteries-included approach that is probably not viable for an official SDK. In yet others, we simply think there is room for API refinement, and we should take this opportunity to consider our options. @@ -332,7 +332,7 @@ marshalling/unmarshalling can be delegated to the business logic of the client or server. For union types, which can't be represented in Go (specifically `Content` and -`Resource`), we prefer distinguished unions: struct types with fields +`ResourceContents`), we prefer distinguished unions: struct types with fields corresponding to the union of all properties for union elements. For brevity, only a few examples are shown here: @@ -353,11 +353,11 @@ type CallToolResult struct { // The Type field distinguishes the type of the content. // At most one of Text, MIMEType, Data, and Resource is non-zero. type Content struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Data []byte `json:"data,omitempty"` - Resource *Resource `json:"resource,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitempty"` + Resource *ResourceContents `json:"resource,omitempty"` } ``` @@ -386,7 +386,7 @@ change. Following the terminology of the [spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), -we call the logical connection between a client and server a "session". There +we call the logical connection between a client and server a "session." 
There must necessarily be a `ClientSession` and a `ServerSession`, corresponding to the APIs available from the client and server perspective, respectively. @@ -436,7 +436,7 @@ transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) if err != nil { ... } // Call a tool on the server. -content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil}) +content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil) ... return session.Close() ``` @@ -446,7 +446,7 @@ A server that can handle that client call would look like this: ```go // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) -server.AddTool(mcp.NewTool("greet", "say hi", SayHi)) +server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) // Run the server over stdin/stdout, until the client disconnects. transport := mcp.NewStdIOTransport() session, err := server.Connect(ctx, transport) @@ -541,8 +541,8 @@ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesP ### Middleware -We provide a mechanism to add MCP-level middleware, which runs after the -request has been parsed, but before any normal handling. +We provide a mechanism to add MCP-level middleware on the server side, which runs after the +request has been parsed but before any normal handling. ```go // A Dispatcher dispatches an MCP message to the appropriate handler. @@ -551,14 +551,14 @@ request has been parsed, but before any normal handling. type Dispatcher func(ctx context.Context, s *ServerSession, method string, params any) (result any, err error) // AddDispatchers calls each function from right to left on the previous result, beginning -// with the server's current dispatcher, and installs the result as the new handler. -func (*Server) AddDispatchers(middleware ...func(Handler) Handler)) +// with the server's current dispatcher, and installs the result as the new dispatcher. 
+func (*Server) AddDispatchers(middleware ...func(Dispatcher) Dispatcher)) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.Handler) mcp.Handler { +func withLogging(h mcp.Dispatcher) mcp.Dispatcher { return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() @@ -621,11 +621,9 @@ The server observes a client cancellation as a cancelled context. A caller can request progress notifications by setting the `ProgressToken` field on any request. ```go -type ProgressToken any // string or int - type XXXParams struct { // where XXX is each type of call ... - ProgressToken ProgressToken + ProgressToken any // string or int } ``` @@ -681,7 +679,7 @@ Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots` // AddRoots adds the given roots to the client, // replacing any with the same URIs, // and notifies any connected servers. -func (*Client) AddRoots(roots ...Root) +func (*Client) AddRoots(roots ...*Root) // RemoveRoots removes the roots with the given URIs. // and notifies any connected servers if the list has changed. @@ -689,7 +687,7 @@ func (*Client) AddRoots(roots ...Root) func (*Client) RemoveRoots(uris ...string) ``` -Servers can call the spec method `ListRoots` to get the roots. If a server installs a +Server sessions can call the spec method `ListRoots` to get the roots. If a server installs a `RootsChangedHandler`, it will be called when the client sends a roots-changed notification, which happens whenever the list of roots changes after a connection has been established. @@ -705,7 +703,7 @@ type ServerOptions { ### Sampling Clients that support sampling are created with a `CreateMessageHandler` option -for handling server calls. To perform sampling, a server calls the spec method `CreateMessage`. +for handling server calls. 
To perform sampling, a server session calls the spec method `CreateMessage`. ```go type ClientOptions struct { @@ -742,7 +740,7 @@ Add tools to a server with `AddTools`: ```go server.AddTools( mcp.NewTool("add", "add numbers", addHandler), - mcp.NewTools("subtract, subtract numbers", subHandler)) + mcp.NewTool("subtract, subtract numbers", subHandler)) ``` Remove them by name with `RemoveTools`: @@ -836,7 +834,7 @@ Schemas are validated on the server before the tool handler is called. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. -Clients can call the spec method `ListTools` or an iterator method `Tools` +Client sessions can call the spec method `ListTools` or an iterator method `Tools` to list the available tools. **Differences from mcp-go**: using variadic options to configure tools was @@ -862,7 +860,7 @@ each occur only once (and in an SDK that wraps mcp-go). For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. -(similarly for Delete/Remove). +(Similarly for Delete/Remove). ### Prompts @@ -893,8 +891,8 @@ server.AddPrompts( server.RemovePrompts("code_review") ``` -Clients can call the spec method `ListPrompts` or an iterator method `Prompts` -to list the available prompts and the spec method `GetPrompt` to get one. +Client sessions can call the spec method `ListPrompts` or the iterator method `Prompts` +to list the available prompts, and the spec method `GetPrompt` to get one. **Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. 
We provide @@ -926,7 +924,8 @@ type ServerResourceTemplate struct { ``` To add a resource or resource template to a server, users call the `AddResources` and -`AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s: +`AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. +We also provide methods to remove them. ```go func (*Server) AddResources(...*ServerResource) @@ -938,14 +937,12 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) The `ReadResource` method finds a resource or resource template matching the argument URI and calls its assocated handler. -If the argument URI matches a template, the `Resource` argument to the handler is constructed -from the fields in the `ResourceTemplate`. To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: ```go // FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. // It protects against path traversal attacks. -// It will not read any file that is not in the root set of the client requesting the resource. +// It will not read any file that is not in the root set of the client session requesting the resource. func (*Server) FileResourceHandler(dir string) ResourceHandler ``` Here is an example: @@ -957,17 +954,8 @@ s.AddResources(&mcp.ServerResource{ Handler: s.FileReadResourceHandler("/public")}) ``` -Servers support all of the resource-related spec methods: - -- `ListResources` and `ListResourceTemplates` for listings. -- `ReadResource` to get the contents of a resource. -- `Subscribe` and `Unsubscribe` to manage subscriptions on resources. - -We also provide iterator methods `Resources` and `ResourceTemplates`. - -`ReadResource` checks the incoming URI against the server's list of -resources and resource templates to make sure it matches one of them, -then returns the result of calling the associated reader function. 
+Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, +and the corresponding iterator methods `Resources` and `ResourceTemplates`. #### Subscriptions @@ -1017,6 +1005,7 @@ type ClientOptions struct { ... ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) + // For both resources and resource templates. ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) } ``` @@ -1049,8 +1038,8 @@ type ServerOptions { } ``` -ServerSessions have access to a `slog.Logger` that writes to the client. A call to -a log method like `Info` is translated to a `LoggingMessageNotification` as +Server sessions have a field `Logger` holding a `slog.Logger` that writes to the client session. +A call to a log method like `Info` is translated to a `LoggingMessageNotification` as follows: - The attributes and the message populate the "data" property with the @@ -1060,11 +1049,11 @@ follows: - If the `LoggerName` server option is set, it populates the "logger" property. - The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the - corresponding levels in the MCP spec. The other spec levels will be mapped + corresponding levels in the MCP spec. The other spec levels map to integers between the slog levels. For example, "notice" is level 2 because it is between "warning" (slog value 4) and "info" (slog value 0). The `mcp` package defines consts for these levels. To log at the "notice" - level, a handler would call `session.Log(ctx, mcp.LevelNotice, "message")`. + level, a handler would call `session.Logger.Log(ctx, mcp.LevelNotice, "message")`. 
A client that wishes to receive log messages must provide a handler: @@ -1080,7 +1069,7 @@ type ClientOptions struct { Servers initiate pagination for `ListTools`, `ListPrompts`, `ListResources`, and `ListResourceTemplates`, dictating the page size and providing a `NextCursor` field in the Result if more pages exist. The SDK implements keyset -pagination, using the `unique ID` as the key for a stable sort order and encoding +pagination, using the unique ID of the feature as the key for a stable sort order and encoding the cursor as an opaque string. For server implementations, the page size for the list operation may be From baeb0dae388387e606b04319fcb10c770143d7b5 Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Wed, 14 May 2025 21:49:32 -0600 Subject: [PATCH 071/196] cmd/deadcode: respect unused symbols inside all loaded modules Current deadcode uses the first package module path as a regular expression to filter when check deadcode code for packages. However, when users load multiple modules, the filter will respect one of them only. This CL constructs the filter regex based on all loaded modules rather than use the first privileging the first package. Fixes golang/go#73652 Change-Id: Id7b9eb8274141cd2d6362da01366cbc45c87eebc Reviewed-on: https://go-review.googlesource.com/c/tools/+/671916 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- cmd/deadcode/deadcode.go | 13 +++++++-- cmd/deadcode/doc.go | 6 ++-- cmd/deadcode/testdata/issue73652.txtar | 39 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 cmd/deadcode/testdata/issue73652.txtar diff --git a/cmd/deadcode/deadcode.go b/cmd/deadcode/deadcode.go index e164dc22ba8..0c0b7ec394e 100644 --- a/cmd/deadcode/deadcode.go +++ b/cmd/deadcode/deadcode.go @@ -132,8 +132,17 @@ func main() { // If -filter is unset, use first module (if available). 
if *filterFlag == "" { - if mod := initial[0].Module; mod != nil && mod.Path != "" { - *filterFlag = "^" + regexp.QuoteMeta(mod.Path) + "\\b" + seen := make(map[string]bool) + var patterns []string + for _, pkg := range initial { + if pkg.Module != nil && pkg.Module.Path != "" && !seen[pkg.Module.Path] { + seen[pkg.Module.Path] = true + patterns = append(patterns, regexp.QuoteMeta(pkg.Module.Path)) + } + } + + if patterns != nil { + *filterFlag = "^(" + strings.Join(patterns, "|") + ")\\b" } else { *filterFlag = "" // match any } diff --git a/cmd/deadcode/doc.go b/cmd/deadcode/doc.go index 66a150dd19d..bd474248e55 100644 --- a/cmd/deadcode/doc.go +++ b/cmd/deadcode/doc.go @@ -5,7 +5,7 @@ /* The deadcode command reports unreachable functions in Go programs. -Usage: deadcode [flags] package... + Usage: deadcode [flags] package... The deadcode command loads a Go program from source then uses Rapid Type Analysis (RTA) to build a call graph of all the functions @@ -25,8 +25,8 @@ function without an "Output:" comment is merely documentation: it is dead code, and does not contribute coverage. The -filter flag restricts results to packages that match the provided -regular expression; its default value is the module name of the first -package. Use -filter= to display all results. +regular expression; its default value matches the listed packages and any other +packages belonging to the same modules. Use -filter= to display all results. Example: show all dead code within the gopls module: diff --git a/cmd/deadcode/testdata/issue73652.txtar b/cmd/deadcode/testdata/issue73652.txtar new file mode 100644 index 00000000000..e3cf00f5719 --- /dev/null +++ b/cmd/deadcode/testdata/issue73652.txtar @@ -0,0 +1,39 @@ +# Test deadcode usage under go.work. + + deadcode ./svc/... ./lib/... + want "unreachable func: A" + +# different order of path under the same go.work should behave the same. + + deadcode ./svc/... ./lib/... 
+ want "unreachable func: A" + + +-- go.work -- +go 1.18 + +use ( + ./lib + ./svc +) + +-- lib/go.mod -- +module lib.com + +go 1.18 + +-- lib/a/a.go -- +package a + +func A() {} + +-- svc/go.mod -- +module svc.com + +go 1.18 + +-- svc/s/main.go -- +package main + +func main() { println("main") } + From 43dd7128941bbc8afcc00ae4e4d4c2ed897e98bc Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 13 May 2025 18:01:17 -0400 Subject: [PATCH 072/196] gopls/internal/protocol: make FoldingRange fields optional For these fields, the protocol defines "missing" to mean something different from zero, so they need to be indirect. (An alternative fix would be to remove omitempty, which would work for the server since it always sets the fields, but would not work for a client using the protocol package, since it would have no way to distinguish zero from unset.) Fixes golang/go#71489 Change-Id: I36066f7ef9e41683c94cc19bb310e30a38763fde Reviewed-on: https://go-review.googlesource.com/c/tools/+/672416 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- gopls/internal/cmd/folding_range.go | 9 ++++---- gopls/internal/golang/folding_range.go | 18 +++++++++------- gopls/internal/protocol/generate/generate.go | 2 -- gopls/internal/protocol/generate/tables.go | 22 +++++++++++++------- gopls/internal/protocol/tsprotocol.go | 8 +++---- gopls/internal/test/marker/marker_test.go | 5 +++-- 6 files changed, 37 insertions(+), 27 deletions(-) diff --git a/gopls/internal/cmd/folding_range.go b/gopls/internal/cmd/folding_range.go index f48feee5b2c..af45d0b0364 100644 --- a/gopls/internal/cmd/folding_range.go +++ b/gopls/internal/cmd/folding_range.go @@ -59,11 +59,12 @@ func (r *foldingRanges) Run(ctx context.Context, args ...string) error { } for _, r := range ranges { + // We assume our server always supplies these fields. 
fmt.Printf("%v:%v-%v:%v\n", - r.StartLine+1, - r.StartCharacter+1, - r.EndLine+1, - r.EndCharacter+1, + *r.StartLine+1, + *r.StartCharacter+1, + *r.EndLine+1, + *r.EndCharacter+1, ) } diff --git a/gopls/internal/golang/folding_range.go b/gopls/internal/golang/folding_range.go index eed31e92944..2cf9f9a6b94 100644 --- a/gopls/internal/golang/folding_range.go +++ b/gopls/internal/golang/folding_range.go @@ -127,10 +127,10 @@ func FoldingRange(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, // Sort by start position. slices.SortFunc(ranges, func(x, y protocol.FoldingRange) int { - if d := cmp.Compare(x.StartLine, y.StartLine); d != 0 { + if d := cmp.Compare(*x.StartLine, *y.StartLine); d != 0 { return d } - return cmp.Compare(x.StartCharacter, y.StartCharacter) + return cmp.Compare(*x.StartCharacter, *y.StartCharacter) }) return ranges, nil @@ -232,11 +232,15 @@ func commentsFoldingRange(pgf *parsego.File) (comments []protocol.FoldingRange) func foldingRange(kind protocol.FoldingRangeKind, rng protocol.Range) protocol.FoldingRange { return protocol.FoldingRange{ - // I have no idea why LSP doesn't use a protocol.Range here. - StartLine: rng.Start.Line, - StartCharacter: rng.Start.Character, - EndLine: rng.End.Line, - EndCharacter: rng.End.Character, + // (I guess LSP doesn't use a protocol.Range here + // because missing means something different from zero.) + StartLine: varOf(rng.Start.Line), + StartCharacter: varOf(rng.Start.Character), + EndLine: varOf(rng.End.Line), + EndCharacter: varOf(rng.End.Character), Kind: string(kind), } } + +// varOf returns a new variable whose value is x. 
+func varOf[T any](x T) *T { return &x } diff --git a/gopls/internal/protocol/generate/generate.go b/gopls/internal/protocol/generate/generate.go index fef8ef417eb..72a2a0c5ad2 100644 --- a/gopls/internal/protocol/generate/generate.go +++ b/gopls/internal/protocol/generate/generate.go @@ -78,8 +78,6 @@ func propStar(name string, t NameType, gotype string) (omitempty, indirect bool) switch newStar { case nothing: indirect, omitempty = false, false - case wantStar: - indirect, omitempty = false, false case wantOpt: indirect, omitempty = false, true case wantOptStar: diff --git a/gopls/internal/protocol/generate/tables.go b/gopls/internal/protocol/generate/tables.go index c0841a2334b..eccaf9cd1c3 100644 --- a/gopls/internal/protocol/generate/tables.go +++ b/gopls/internal/protocol/generate/tables.go @@ -6,14 +6,14 @@ package main import "log" -// prop combines the name of a property with the name of the structure it is in. +// prop combines the name of a property (class.field) with the name of +// the structure it is in, using LSP field capitalization. type prop [2]string const ( - nothing = iota - wantStar - wantOpt - wantOptStar + nothing = iota + wantOpt // omitempty + wantOptStar // omitempty, indirect ) // goplsStar records the optionality of each field in the protocol. 
@@ -37,13 +37,19 @@ var goplsStar = map[prop]int{ {"Diagnostic", "severity"}: wantOpt, // nil checks or more careful thought {"DidSaveTextDocumentParams", "text"}: wantOptStar, // capabilities_test.go:112 logic {"DocumentHighlight", "kind"}: wantOpt, // need temporary variables - {"Hover", "range"}: wantOpt, // complex expressions - {"InlayHint", "kind"}: wantOpt, // temporary variables + + {"FoldingRange", "startLine"}: wantOptStar, // unset != zero (#71489) + {"FoldingRange", "startCharacter"}: wantOptStar, // unset != zero (#71489) + {"FoldingRange", "endLine"}: wantOptStar, // unset != zero (#71489) + {"FoldingRange", "endCharacter"}: wantOptStar, // unset != zero (#71489) + + {"Hover", "range"}: wantOpt, // complex expressions + {"InlayHint", "kind"}: wantOpt, // temporary variables {"TextDocumentClientCapabilities", "codeAction"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "completion"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "documentSymbol"}: wantOpt, // A.B.C.D - {"TextDocumentClientCapabilities", "publishDiagnostics"}: wantOpt, //A.B.C.D + {"TextDocumentClientCapabilities", "publishDiagnostics"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "semanticTokens"}: wantOpt, // A.B.C.D {"TextDocumentContentChangePartial", "range"}: wantOptStar, // == nil test {"TextDocumentSyncOptions", "change"}: wantOpt, // &constant diff --git a/gopls/internal/protocol/tsprotocol.go b/gopls/internal/protocol/tsprotocol.go index 7306f62a7ad..a759eb2ed89 100644 --- a/gopls/internal/protocol/tsprotocol.go +++ b/gopls/internal/protocol/tsprotocol.go @@ -2375,14 +2375,14 @@ type FileSystemWatcher struct { type FoldingRange struct { // The zero-based start line of the range to fold. The folded area starts after the line's last character. // To be valid, the end must be zero or larger and smaller than the number of lines in the document. 
- StartLine uint32 `json:"startLine"` + StartLine *uint32 `json:"startLine,omitempty"` // The zero-based character offset from where the folded range starts. If not defined, defaults to the length of the start line. - StartCharacter uint32 `json:"startCharacter"` + StartCharacter *uint32 `json:"startCharacter,omitempty"` // The zero-based end line of the range to fold. The folded area ends with the line's last character. // To be valid, the end must be zero or larger and smaller than the number of lines in the document. - EndLine uint32 `json:"endLine"` + EndLine *uint32 `json:"endLine,omitempty"` // The zero-based character offset before the folded range ends. If not defined, defaults to the length of the end line. - EndCharacter uint32 `json:"endCharacter"` + EndCharacter *uint32 `json:"endCharacter,omitempty"` // Describes the kind of the folding range such as 'comment' or 'region'. The kind // is used to categorize folding ranges and used by commands like 'Fold all comments'. // See {@link FoldingRangeKind} for an enumeration of standardized kinds. diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index d2d4f899d48..8cc7c56320d 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -1691,8 +1691,9 @@ func foldingRangeMarker(mark marker, g *Golden) { }) } for i, rng := range ranges { - insert(rng.StartLine, rng.StartCharacter, fmt.Sprintf("<%d kind=%q>", i, rng.Kind)) - insert(rng.EndLine, rng.EndCharacter, fmt.Sprintf("", i)) + // We assume the server populates these optional fields. 
+ insert(*rng.StartLine, *rng.StartCharacter, fmt.Sprintf("<%d kind=%q>", i, rng.Kind)) + insert(*rng.EndLine, *rng.EndCharacter, fmt.Sprintf("", i)) } filename := mark.path() mapper, err := env.Editor.Mapper(filename) From 84fa02a987bf078de641c30936d8b6067f869a6e Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 12 May 2025 13:43:46 -0400 Subject: [PATCH 073/196] x/tools: gofix -fix -test ./... Updates golang/go#70859 Change-Id: I9a4345b8e973aaf45ed052ccb8b9807d0868eca8 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672016 Auto-Submit: Alan Donovan Commit-Queue: Alan Donovan Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- go/analysis/passes/gofix/gofix.go | 3 +-- go/types/internal/play/play.go | 3 +-- .../analysis/fillreturns/fillreturns.go | 5 ++-- gopls/internal/analysis/maprange/maprange.go | 8 +++--- gopls/internal/analysis/modernize/bloop.go | 9 +++---- .../internal/analysis/modernize/fmtappendf.go | 2 +- gopls/internal/analysis/modernize/maps.go | 3 +-- gopls/internal/analysis/modernize/minmax.go | 3 +-- .../internal/analysis/modernize/modernize.go | 9 +++---- gopls/internal/analysis/modernize/rangeint.go | 5 ++-- .../analysis/modernize/slicescontains.go | 3 +-- .../internal/analysis/modernize/stringsseq.go | 2 +- .../analysis/modernize/testingcontext.go | 2 +- .../internal/analysis/nonewvars/nonewvars.go | 3 +-- .../analysis/noresultvalues/noresultvalues.go | 3 +-- .../analysis/unusedparams/unusedparams.go | 5 ++-- gopls/internal/cache/parsego/file.go | 4 +-- gopls/internal/cache/parsego/parse.go | 3 +-- gopls/internal/cmd/cmd.go | 2 +- gopls/internal/golang/codeaction.go | 6 ++--- gopls/internal/golang/completion/newfile.go | 2 +- gopls/internal/golang/diagnostics.go | 2 +- gopls/internal/golang/extract.go | 10 +++---- gopls/internal/golang/format.go | 2 +- gopls/internal/golang/implementation.go | 6 ++--- gopls/internal/golang/inlay_hint.go | 18 ++++++------- gopls/internal/golang/inline_all.go | 2 +- 
gopls/internal/golang/invertifcondition.go | 4 +-- gopls/internal/golang/known_packages.go | 2 +- gopls/internal/golang/lines.go | 8 +++--- gopls/internal/golang/rename.go | 8 +++--- gopls/internal/golang/rename_check.go | 10 +++---- gopls/internal/golang/symbols.go | 2 +- gopls/internal/protocol/protocol.go | 3 +-- gopls/internal/server/command.go | 6 ++--- gopls/internal/server/diagnostics.go | 2 +- gopls/internal/server/link.go | 2 +- internal/analysisinternal/analysis.go | 3 +-- internal/analysisinternal/analysis_test.go | 5 ++-- internal/astutil/edge/edge.go | 2 +- internal/gofix/findgofix/findgofix.go | 4 +-- internal/gofix/gofix.go | 15 +++++------ internal/typesinternal/typeindex/typeindex.go | 27 +++++++++---------- .../typesinternal/typeindex/typeindex_test.go | 3 +-- 44 files changed, 106 insertions(+), 125 deletions(-) diff --git a/go/analysis/passes/gofix/gofix.go b/go/analysis/passes/gofix/gofix.go index 706e0759c3a..f6b66156276 100644 --- a/go/analysis/passes/gofix/gofix.go +++ b/go/analysis/passes/gofix/gofix.go @@ -12,7 +12,6 @@ import ( "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/gofix/findgofix" ) @@ -28,7 +27,7 @@ var Analyzer = &analysis.Analyzer{ } func run(pass *analysis.Pass) (any, error) { - root := cursor.Root(pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)) + root := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector).Root() findgofix.Find(pass, root, nil) return nil, nil } diff --git a/go/types/internal/play/play.go b/go/types/internal/play/play.go index 77a90502135..f1a3b95e743 100644 --- a/go/types/internal/play/play.go +++ b/go/types/internal/play/play.go @@ -34,7 +34,6 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" - "golang.org/x/tools/internal/astutil/cursor" 
"golang.org/x/tools/internal/typeparams" ) @@ -168,7 +167,7 @@ func handleSelectJSON(w http.ResponseWriter, req *http.Request) { // It's usually the same, but may differ in edge // cases (e.g. around FuncType.Func). inspect := inspector.New([]*ast.File{file}) - if cur, ok := cursor.Root(inspect).FindByPos(startPos, endPos); ok { + if cur, ok := inspect.Root().FindByPos(startPos, endPos); ok { fmt.Fprintf(out, "Cursor.FindPos().Enclosing() = %v\n", slices.Collect(cur.Enclosing())) } else { diff --git a/gopls/internal/analysis/fillreturns/fillreturns.go b/gopls/internal/analysis/fillreturns/fillreturns.go index b2cc1caf872..d6502db5773 100644 --- a/gopls/internal/analysis/fillreturns/fillreturns.go +++ b/gopls/internal/analysis/fillreturns/fillreturns.go @@ -21,7 +21,6 @@ import ( "golang.org/x/tools/gopls/internal/fuzzy" "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typesinternal" ) @@ -50,7 +49,7 @@ outer: if !ok { continue // no position information } - curErr, ok := cursor.Root(inspect).FindByPos(start, end) + curErr, ok := inspect.Root().FindByPos(start, end) if !ok { continue // can't find node } @@ -227,6 +226,6 @@ func fixesError(err types.Error) bool { // enclosingFunc returns the cursor for the innermost Func{Decl,Lit} // that encloses c, if any. 
-func enclosingFunc(c cursor.Cursor) (cursor.Cursor, bool) { +func enclosingFunc(c inspector.Cursor) (inspector.Cursor, bool) { return moreiters.First(c.Enclosing((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil))) } diff --git a/gopls/internal/analysis/maprange/maprange.go b/gopls/internal/analysis/maprange/maprange.go index eed04b14e72..c74e684b827 100644 --- a/gopls/internal/analysis/maprange/maprange.go +++ b/gopls/internal/analysis/maprange/maprange.go @@ -11,11 +11,11 @@ import ( "go/types" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/edge" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal/typeindex" "golang.org/x/tools/internal/versions" ) @@ -66,7 +66,7 @@ func run(pass *analysis.Pass) (any, error) { // For certain patterns involving x/exp/maps.Keys before Go 1.22, it reports // a diagnostic about potential incorrect usage without a suggested fix. // No diagnostic is reported if the range statement doesn't require changes. -func analyzeRangeStmt(pass *analysis.Pass, callee types.Object, curCall cursor.Cursor) { +func analyzeRangeStmt(pass *analysis.Pass, callee types.Object, curCall inspector.Cursor) { var ( call = curCall.Node().(*ast.CallExpr) rangeStmt = curCall.Parent().Node().(*ast.RangeStmt) @@ -152,7 +152,7 @@ func isSet(expr ast.Expr) bool { // fileUses reports whether the file containing the specified cursor // uses at least the specified version of Go (e.g. "go1.24"). 
-func fileUses(info *types.Info, c cursor.Cursor, version string) bool { +func fileUses(info *types.Info, c inspector.Cursor, version string) bool { c, _ = moreiters.First(c.Enclosing((*ast.File)(nil))) file := c.Node().(*ast.File) return !versions.Before(info.FileVersions[file], version) diff --git a/gopls/internal/analysis/modernize/bloop.go b/gopls/internal/analysis/modernize/bloop.go index ea2359c7fb6..ed6c1b3f665 100644 --- a/gopls/internal/analysis/modernize/bloop.go +++ b/gopls/internal/analysis/modernize/bloop.go @@ -17,7 +17,6 @@ import ( "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typesinternal/typeindex" ) @@ -43,12 +42,12 @@ func bloop(pass *analysis.Pass) { // edits computes the text edits for a matched for/range loop // at the specified cursor. b is the *testing.B value, and // (start, end) is the portion using b.N to delete. - edits := func(curLoop cursor.Cursor, b ast.Expr, start, end token.Pos) (edits []analysis.TextEdit) { + edits := func(curLoop inspector.Cursor, b ast.Expr, start, end token.Pos) (edits []analysis.TextEdit) { curFn, _ := enclosingFunc(curLoop) // Within the same function, delete all calls to // b.{Start,Stop,Timer} that precede the loop. filter := []ast.Node{(*ast.ExprStmt)(nil), (*ast.FuncLit)(nil)} - curFn.Inspect(filter, func(cur cursor.Cursor) (descend bool) { + curFn.Inspect(filter, func(cur inspector.Cursor) (descend bool) { node := cur.Node() if is[*ast.FuncLit](node) { return false // don't descend into FuncLits (e.g. sub-benchmarks) @@ -156,7 +155,7 @@ func bloop(pass *analysis.Pass) { } // uses reports whether the subtree cur contains a use of obj. 
-func uses(index *typeindex.Index, cur cursor.Cursor, obj types.Object) bool { +func uses(index *typeindex.Index, cur inspector.Cursor, obj types.Object) bool { for use := range index.Uses(obj) { if cur.Contains(use) { return true @@ -167,6 +166,6 @@ func uses(index *typeindex.Index, cur cursor.Cursor, obj types.Object) bool { // enclosingFunc returns the cursor for the innermost Func{Decl,Lit} // that encloses c, if any. -func enclosingFunc(c cursor.Cursor) (cursor.Cursor, bool) { +func enclosingFunc(c inspector.Cursor) (inspector.Cursor, bool) { return moreiters.First(c.Enclosing((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil))) } diff --git a/gopls/internal/analysis/modernize/fmtappendf.go b/gopls/internal/analysis/modernize/fmtappendf.go index 6b01d38050e..cd9dfa5e311 100644 --- a/gopls/internal/analysis/modernize/fmtappendf.go +++ b/gopls/internal/analysis/modernize/fmtappendf.go @@ -11,8 +11,8 @@ import ( "strings" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/edge" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal/typeindex" ) diff --git a/gopls/internal/analysis/modernize/maps.go b/gopls/internal/analysis/modernize/maps.go index 1a5e2c3eeee..1e32233b5b6 100644 --- a/gopls/internal/analysis/modernize/maps.go +++ b/gopls/internal/analysis/modernize/maps.go @@ -16,7 +16,6 @@ import ( "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typeparams" ) @@ -51,7 +50,7 @@ func mapsloop(pass *analysis.Pass) { // check is called for each statement of this form: // for k, v := range x { m[k] = v } - check := func(file *ast.File, curRange cursor.Cursor, assign *ast.AssignStmt, m, x ast.Expr) { + check := func(file *ast.File, curRange inspector.Cursor, assign *ast.AssignStmt, m, x ast.Expr) { // 
Is x a map or iter.Seq2? tx := types.Unalias(info.TypeOf(x)) diff --git a/gopls/internal/analysis/modernize/minmax.go b/gopls/internal/analysis/modernize/minmax.go index 0e43ee11c3d..6c896289e1e 100644 --- a/gopls/internal/analysis/modernize/minmax.go +++ b/gopls/internal/analysis/modernize/minmax.go @@ -15,7 +15,6 @@ import ( "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typeparams" ) @@ -38,7 +37,7 @@ func minmax(pass *analysis.Pass) { // check is called for all statements of this form: // if a < b { lhs = rhs } - check := func(file *ast.File, curIfStmt cursor.Cursor, compare *ast.BinaryExpr) { + check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) { var ( ifStmt = curIfStmt.Node().(*ast.IfStmt) tassign = ifStmt.Body.List[0].(*ast.AssignStmt) diff --git a/gopls/internal/analysis/modernize/modernize.go b/gopls/internal/analysis/modernize/modernize.go index 44992c3aa14..d092c10c313 100644 --- a/gopls/internal/analysis/modernize/modernize.go +++ b/gopls/internal/analysis/modernize/modernize.go @@ -23,7 +23,6 @@ import ( "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/stdlib" "golang.org/x/tools/internal/versions" ) @@ -136,9 +135,9 @@ func isIntLiteral(info *types.Info, e ast.Expr, n int64) bool { // TODO(adonovan): opt: eliminate this function, instead following the // approach of [fmtappendf], which uses typeindex and [fileUses]. // See "Tip" at [fileUses] for motivation. 
-func filesUsing(inspect *inspector.Inspector, info *types.Info, version string) iter.Seq[cursor.Cursor] { - return func(yield func(cursor.Cursor) bool) { - for curFile := range cursor.Root(inspect).Children() { +func filesUsing(inspect *inspector.Inspector, info *types.Info, version string) iter.Seq[inspector.Cursor] { + return func(yield func(inspector.Cursor) bool) { + for curFile := range inspect.Root().Children() { file := curFile.Node().(*ast.File) if !versions.Before(info.FileVersions[file], version) && !yield(curFile) { break @@ -161,7 +160,7 @@ func fileUses(info *types.Info, file *ast.File, version string) bool { } // enclosingFile returns the syntax tree for the file enclosing c. -func enclosingFile(c cursor.Cursor) *ast.File { +func enclosingFile(c inspector.Cursor) *ast.File { c, _ = moreiters.First(c.Enclosing((*ast.File)(nil))) return c.Node().(*ast.File) } diff --git a/gopls/internal/analysis/modernize/rangeint.go b/gopls/internal/analysis/modernize/rangeint.go index 1d3f4b5db0c..7858f365d4d 100644 --- a/gopls/internal/analysis/modernize/rangeint.go +++ b/gopls/internal/analysis/modernize/rangeint.go @@ -12,12 +12,11 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal" "golang.org/x/tools/internal/typesinternal/typeindex" ) @@ -246,7 +245,7 @@ func rangeint(pass *analysis.Pass) { // // This function is valid only for scalars (x = ...), // not for aggregates (x.a[i] = ...) 
-func isScalarLvalue(info *types.Info, curId cursor.Cursor) bool { +func isScalarLvalue(info *types.Info, curId inspector.Cursor) bool { // Unfortunately we can't simply use info.Types[e].Assignable() // as it is always true for a variable even when that variable is // used only as an r-value. So we must inspect enclosing syntax. diff --git a/gopls/internal/analysis/modernize/slicescontains.go b/gopls/internal/analysis/modernize/slicescontains.go index b5cd56022e1..3f74fef2b5b 100644 --- a/gopls/internal/analysis/modernize/slicescontains.go +++ b/gopls/internal/analysis/modernize/slicescontains.go @@ -16,7 +16,6 @@ import ( "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/typesinternal/typeindex" ) @@ -66,7 +65,7 @@ func slicescontains(pass *analysis.Pass) { // check is called for each RangeStmt of this form: // for i, elem := range s { if cond { ... 
} } - check := func(file *ast.File, curRange cursor.Cursor) { + check := func(file *ast.File, curRange inspector.Cursor) { rng := curRange.Node().(*ast.RangeStmt) ifStmt := rng.Body.List[0].(*ast.IfStmt) diff --git a/gopls/internal/analysis/modernize/stringsseq.go b/gopls/internal/analysis/modernize/stringsseq.go index d32f8be754f..9250a92711d 100644 --- a/gopls/internal/analysis/modernize/stringsseq.go +++ b/gopls/internal/analysis/modernize/stringsseq.go @@ -12,10 +12,10 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal/typeindex" ) diff --git a/gopls/internal/analysis/modernize/testingcontext.go b/gopls/internal/analysis/modernize/testingcontext.go index de52f756ab8..b356a1eb081 100644 --- a/gopls/internal/analysis/modernize/testingcontext.go +++ b/gopls/internal/analysis/modernize/testingcontext.go @@ -14,10 +14,10 @@ import ( "unicode/utf8" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/analysisinternal" typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal/typeindex" ) diff --git a/gopls/internal/analysis/nonewvars/nonewvars.go b/gopls/internal/analysis/nonewvars/nonewvars.go index 62383dc2309..c562f9754d4 100644 --- a/gopls/internal/analysis/nonewvars/nonewvars.go +++ b/gopls/internal/analysis/nonewvars/nonewvars.go @@ -16,7 +16,6 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" 
"golang.org/x/tools/internal/typesinternal" ) @@ -43,7 +42,7 @@ func run(pass *analysis.Pass) (any, error) { if !ok { continue // can't get position info } - curErr, ok := cursor.Root(inspect).FindByPos(start, end) + curErr, ok := inspect.Root().FindByPos(start, end) if !ok { continue // can't find errant node } diff --git a/gopls/internal/analysis/noresultvalues/noresultvalues.go b/gopls/internal/analysis/noresultvalues/noresultvalues.go index 4f095c941c4..12b2720db63 100644 --- a/gopls/internal/analysis/noresultvalues/noresultvalues.go +++ b/gopls/internal/analysis/noresultvalues/noresultvalues.go @@ -16,7 +16,6 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typesinternal" ) @@ -43,7 +42,7 @@ func run(pass *analysis.Pass) (any, error) { if !ok { continue // can't get position info } - curErr, ok := cursor.Root(inspect).FindByPos(start, end) + curErr, ok := inspect.Root().FindByPos(start, end) if !ok { continue // can't find errant node } diff --git a/gopls/internal/analysis/unusedparams/unusedparams.go b/gopls/internal/analysis/unusedparams/unusedparams.go index 824711242da..422e029cd01 100644 --- a/gopls/internal/analysis/unusedparams/unusedparams.go +++ b/gopls/internal/analysis/unusedparams/unusedparams.go @@ -12,11 +12,10 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/util/moreslices" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal" ) @@ -126,7 +125,7 @@ func run(pass *analysis.Pass) (any, error) { // Check each non-address-taken function's parameters are all used. 
funcloop: - for c := range cursor.Root(inspect).Preorder((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)) { + for c := range inspect.Root().Preorder((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)) { var ( fn types.Object // function symbol (*Func, possibly *Var for a FuncLit) ftype *ast.FuncType diff --git a/gopls/internal/cache/parsego/file.go b/gopls/internal/cache/parsego/file.go index 2be4ed4b2ca..ef8a3379b03 100644 --- a/gopls/internal/cache/parsego/file.go +++ b/gopls/internal/cache/parsego/file.go @@ -11,10 +11,10 @@ import ( "go/token" "sync" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" ) // A File contains the results of parsing a Go file. @@ -33,7 +33,7 @@ type File struct { // actual content of the file if we have fixed the AST. Src []byte - Cursor cursor.Cursor // cursor of *ast.File, sans sibling files + Cursor inspector.Cursor // cursor of *ast.File, sans sibling files // fixedSrc and fixedAST report on "fixing" that occurred during parsing of // this file. diff --git a/gopls/internal/cache/parsego/parse.go b/gopls/internal/cache/parsego/parse.go index 9a6bdf03da3..df4d9c8e44d 100644 --- a/gopls/internal/cache/parsego/parse.go +++ b/gopls/internal/cache/parsego/parse.go @@ -28,7 +28,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" ) @@ -125,7 +124,7 @@ func Parse(ctx context.Context, fset *token.FileSet, uri protocol.DocumentURI, s // Provide a cursor for fast and convenient navigation. 
inspect := inspector.New([]*ast.File{file}) - curFile, _ := cursor.Root(inspect).FirstChild() + curFile, _ := inspect.Root().FirstChild() _ = curFile.Node().(*ast.File) return &File{ diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index fed96388fb4..02c5103de37 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -215,7 +215,7 @@ func isZeroValue(f *flag.Flag, value string) bool { // This works unless the Value type is itself an interface type. typ := reflect.TypeOf(f.Value) var z reflect.Value - if typ.Kind() == reflect.Ptr { + if typ.Kind() == reflect.Pointer { z = reflect.New(typ.Elem()) } else { z = reflect.Zero(typ) diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index a7917fbbda4..07b577de745 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -17,6 +17,7 @@ import ( "strings" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/analysis/fillstruct" "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/cache" @@ -27,7 +28,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/gopls/internal/settings" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/typesinternal" @@ -831,7 +831,7 @@ func selectionContainsStructField(node *ast.StructType, start, end token.Pos, ne // fields within start and end positions. If removeTags is true, it means the // current command is for remove tags rather than add tags, so we only return // true if the struct field found contains a struct tag to remove. 
-func selectionContainsStruct(cursor cursor.Cursor, start, end token.Pos, removeTags bool) bool { +func selectionContainsStruct(cursor inspector.Cursor, start, end token.Pos, removeTags bool) bool { cur, ok := cursor.FindByPos(start, end) if !ok { return false @@ -1084,7 +1084,7 @@ func goAssembly(ctx context.Context, req *codeActionsRequest) error { func toggleCompilerOptDetails(ctx context.Context, req *codeActionsRequest) error { // TODO(adonovan): errors from code action providers should probably be // logged, even if they aren't visible to the client; see https://go.dev/issue/71275. - if meta, err := NarrowestMetadataForFile(ctx, req.snapshot, req.fh.URI()); err == nil { + if meta, err := req.snapshot.NarrowestMetadataForFile(ctx, req.fh.URI()); err == nil { if len(meta.CompiledGoFiles) == 0 { return fmt.Errorf("package %q does not compile file %q", meta.ID, req.fh.URI()) } diff --git a/gopls/internal/golang/completion/newfile.go b/gopls/internal/golang/completion/newfile.go index d9869a2f050..38dcadc238f 100644 --- a/gopls/internal/golang/completion/newfile.go +++ b/gopls/internal/golang/completion/newfile.go @@ -21,7 +21,7 @@ func NewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) (*pr if bs, err := fh.Content(); err != nil || len(bs) != 0 { return nil, err } - meta, err := golang.NarrowestMetadataForFile(ctx, snapshot, fh.URI()) + meta, err := snapshot.NarrowestMetadataForFile(ctx, fh.URI()) if err != nil { return nil, err } diff --git a/gopls/internal/golang/diagnostics.go b/gopls/internal/golang/diagnostics.go index f65ca4f7047..6708d32fcbb 100644 --- a/gopls/internal/golang/diagnostics.go +++ b/gopls/internal/golang/diagnostics.go @@ -16,7 +16,7 @@ import ( // DiagnoseFile returns pull-based diagnostics for the given file. 
func DiagnoseFile(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI) ([]*cache.Diagnostic, error) { - mp, err := NarrowestMetadataForFile(ctx, snapshot, uri) + mp, err := snapshot.NarrowestMetadataForFile(ctx, uri) if err != nil { return nil, err } diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 5e82e430225..f19285b8a3c 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -21,13 +21,13 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typesinternal" ) @@ -391,7 +391,7 @@ func stmtToInsertVarBefore(path []ast.Node, variables []*variable) (ast.Stmt, er // canExtractVariable reports whether the code in the given range can be // extracted to a variable (or constant). It returns the selected expression or, if 'all', // all structurally equivalent expressions within the same function body, in lexical order. -func canExtractVariable(info *types.Info, curFile cursor.Cursor, start, end token.Pos, all bool) ([]ast.Expr, error) { +func canExtractVariable(info *types.Info, curFile inspector.Cursor, start, end token.Pos, all bool) ([]ast.Expr, error) { if start == end { return nil, fmt.Errorf("empty selection") } @@ -1238,7 +1238,7 @@ func moveParamToFrontIfFound(params []ast.Expr, paramTypes []*ast.Field, x, sel // their cursors for whitespace. To support this use case, we must manually adjust the // ranges to match the correct AST node. 
In this particular example, we would adjust // rng.Start forward to the start of 'if' and rng.End backward to after '}'. -func adjustRangeForCommentsAndWhiteSpace(tok *token.File, start, end token.Pos, content []byte, curFile cursor.Cursor) (token.Pos, token.Pos, error) { +func adjustRangeForCommentsAndWhiteSpace(tok *token.File, start, end token.Pos, content []byte, curFile inspector.Cursor) (token.Pos, token.Pos, error) { file := curFile.Node().(*ast.File) // TODO(adonovan): simplify, using Cursor. @@ -1568,7 +1568,7 @@ type fnExtractParams struct { // canExtractFunction reports whether the code in the given range can be // extracted to a function. -func canExtractFunction(tok *token.File, start, end token.Pos, src []byte, curFile cursor.Cursor) (*fnExtractParams, bool, bool, error) { +func canExtractFunction(tok *token.File, start, end token.Pos, src []byte, curFile inspector.Cursor) (*fnExtractParams, bool, bool, error) { if start == end { return nil, false, false, fmt.Errorf("start and end are equal") } @@ -2022,7 +2022,7 @@ func replaceBranchStmtWithReturnStmt(block ast.Node, br *ast.BranchStmt, ret *as // freeBranches returns all branch statements beneath cur whose continuation // lies outside the (start, end) range. -func freeBranches(info *types.Info, cur cursor.Cursor, start, end token.Pos) (free []*ast.BranchStmt) { +func freeBranches(info *types.Info, cur inspector.Cursor, start, end token.Pos) (free []*ast.BranchStmt) { nextBranch: for curBr := range cur.Preorder((*ast.BranchStmt)(nil)) { br := curBr.Node().(*ast.BranchStmt) diff --git a/gopls/internal/golang/format.go b/gopls/internal/golang/format.go index ded00deef38..ef98580abff 100644 --- a/gopls/internal/golang/format.go +++ b/gopls/internal/golang/format.go @@ -79,7 +79,7 @@ func Format(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]pr // Can this, for example, result in inconsistent formatting across saves, // due to pending calls to packages.Load? 
var opts gofumptFormat.Options - meta, err := NarrowestMetadataForFile(ctx, snapshot, fh.URI()) + meta, err := snapshot.NarrowestMetadataForFile(ctx, fh.URI()) if err == nil { if mi := meta.Module; mi != nil { if v := mi.GoVersion; v != "" { diff --git a/gopls/internal/golang/implementation.go b/gopls/internal/golang/implementation.go index 675b232d0eb..678861440da 100644 --- a/gopls/internal/golang/implementation.go +++ b/gopls/internal/golang/implementation.go @@ -18,6 +18,8 @@ import ( "sync" "golang.org/x/sync/errgroup" + "golang.org/x/tools/go/ast/edge" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/metadata" @@ -28,8 +30,6 @@ import ( "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/typesinternal" ) @@ -1103,7 +1103,7 @@ func funcDefs(pkg *cache.Package, t types.Type) ([]protocol.Location, error) { // beneathFuncDef reports whether the specified FuncType cursor is a // child of Func{Decl,Lit}. 
-func beneathFuncDef(cur cursor.Cursor) bool { +func beneathFuncDef(cur inspector.Cursor) bool { switch ek, _ := cur.ParentEdge(); ek { case edge.FuncDecl_Type, edge.FuncLit_Type: return true diff --git a/gopls/internal/golang/inlay_hint.go b/gopls/internal/golang/inlay_hint.go index 617231a4f8c..589a809f933 100644 --- a/gopls/internal/golang/inlay_hint.go +++ b/gopls/internal/golang/inlay_hint.go @@ -13,12 +13,12 @@ import ( "go/types" "strings" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/settings" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/typesinternal" @@ -74,7 +74,7 @@ func InlayHint(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pR return hints, nil } -type inlayHintFunc func(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) +type inlayHintFunc func(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) var allInlayHints = map[settings.InlayHint]inlayHintFunc{ settings.AssignVariableTypes: assignVariableTypes, @@ -86,7 +86,7 @@ var allInlayHints = map[settings.InlayHint]inlayHintFunc{ settings.FunctionTypeParameters: funcTypeParams, } -func parameterNames(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func parameterNames(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curCall := range cur.Preorder((*ast.CallExpr)(nil)) { callExpr := curCall.Node().(*ast.CallExpr) t := info.TypeOf(callExpr.Fun) @@ -134,7 +134,7 @@ func parameterNames(info *types.Info, pgf *parsego.File, qual 
types.Qualifier, c } } -func funcTypeParams(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func funcTypeParams(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curCall := range cur.Preorder((*ast.CallExpr)(nil)) { call := curCall.Node().(*ast.CallExpr) id, ok := call.Fun.(*ast.Ident) @@ -164,7 +164,7 @@ func funcTypeParams(info *types.Info, pgf *parsego.File, qual types.Qualifier, c } } -func assignVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func assignVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curAssign := range cur.Preorder((*ast.AssignStmt)(nil)) { stmt := curAssign.Node().(*ast.AssignStmt) if stmt.Tok != token.DEFINE { @@ -176,7 +176,7 @@ func assignVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifi } } -func rangeVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func rangeVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curRange := range cur.Preorder((*ast.RangeStmt)(nil)) { rStmt := curRange.Node().(*ast.RangeStmt) variableType(info, pgf, qual, rStmt.Key, add) @@ -201,7 +201,7 @@ func variableType(info *types.Info, pgf *parsego.File, qual types.Qualifier, e a }) } -func constantValues(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func constantValues(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curDecl := range cur.Preorder((*ast.GenDecl)(nil)) { genDecl := curDecl.Node().(*ast.GenDecl) if genDecl.Tok != token.CONST { @@ -252,7 +252,7 @@ func 
constantValues(info *types.Info, pgf *parsego.File, qual types.Qualifier, c } } -func compositeLiteralFields(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func compositeLiteralFields(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curCompLit := range cur.Preorder((*ast.CompositeLit)(nil)) { compLit, ok := curCompLit.Node().(*ast.CompositeLit) if !ok { @@ -300,7 +300,7 @@ func compositeLiteralFields(info *types.Info, pgf *parsego.File, qual types.Qual } } -func compositeLiteralTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur cursor.Cursor, add func(protocol.InlayHint)) { +func compositeLiteralTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { for curCompLit := range cur.Preorder((*ast.CompositeLit)(nil)) { compLit := curCompLit.Node().(*ast.CompositeLit) typ := info.TypeOf(compLit) diff --git a/gopls/internal/golang/inline_all.go b/gopls/internal/golang/inline_all.go index ec9a458d61a..07a858e00a4 100644 --- a/gopls/internal/golang/inline_all.go +++ b/gopls/internal/golang/inline_all.go @@ -72,7 +72,7 @@ func inlineAllCalls(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Pa { needPkgs := make(map[PackageID]struct{}) for _, ref := range refs { - md, err := NarrowestMetadataForFile(ctx, snapshot, ref.URI) + md, err := snapshot.NarrowestMetadataForFile(ctx, ref.URI) if err != nil { return nil, fmt.Errorf("finding ref metadata: %v", err) } diff --git a/gopls/internal/golang/invertifcondition.go b/gopls/internal/golang/invertifcondition.go index c8cd7deef5e..dcab7da898f 100644 --- a/gopls/internal/golang/invertifcondition.go +++ b/gopls/internal/golang/invertifcondition.go @@ -12,10 +12,10 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" 
"golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" ) // invertIfCondition is a singleFileFixFunc that inverts an if/else statement @@ -245,7 +245,7 @@ func invertAndOr(fset *token.FileSet, expr *ast.BinaryExpr, src []byte) ([]byte, // canInvertIfCondition reports whether we can do invert-if-condition on the // code in the given range. -func canInvertIfCondition(curFile cursor.Cursor, start, end token.Pos) (*ast.IfStmt, bool, error) { +func canInvertIfCondition(curFile inspector.Cursor, start, end token.Pos) (*ast.IfStmt, bool, error) { file := curFile.Node().(*ast.File) // TODO(adonovan): simplify, using Cursor. path, _ := astutil.PathEnclosingInterval(file, start, end) diff --git a/gopls/internal/golang/known_packages.go b/gopls/internal/golang/known_packages.go index 3b320d4f782..92f766471d4 100644 --- a/gopls/internal/golang/known_packages.go +++ b/gopls/internal/golang/known_packages.go @@ -30,7 +30,7 @@ func KnownPackagePaths(ctx context.Context, snapshot *cache.Snapshot, fh file.Ha // This algorithm is expressed in terms of Metadata, not Packages, // so it doesn't cause or wait for type checking. - current, err := NarrowestMetadataForFile(ctx, snapshot, fh.URI()) + current, err := snapshot.NarrowestMetadataForFile(ctx, fh.URI()) if err != nil { return nil, err // e.g. 
context cancelled } diff --git a/gopls/internal/golang/lines.go b/gopls/internal/golang/lines.go index cb161671726..d6eca0feec6 100644 --- a/gopls/internal/golang/lines.go +++ b/gopls/internal/golang/lines.go @@ -18,15 +18,15 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" ) // canSplitLines checks whether we can split lists of elements inside // an enclosing curly bracket/parens into separate lines. -func canSplitLines(curFile cursor.Cursor, fset *token.FileSet, start, end token.Pos) (string, bool, error) { +func canSplitLines(curFile inspector.Cursor, fset *token.FileSet, start, end token.Pos) (string, bool, error) { itemType, items, comments, _, _, _ := findSplitJoinTarget(fset, curFile, nil, start, end) if itemType == "" { return "", false, nil @@ -49,7 +49,7 @@ func canSplitLines(curFile cursor.Cursor, fset *token.FileSet, start, end token. // canJoinLines checks whether we can join lists of elements inside an // enclosing curly bracket/parens into a single line. -func canJoinLines(curFile cursor.Cursor, fset *token.FileSet, start, end token.Pos) (string, bool, error) { +func canJoinLines(curFile inspector.Cursor, fset *token.FileSet, start, end token.Pos) (string, bool, error) { itemType, items, comments, _, _, _ := findSplitJoinTarget(fset, curFile, nil, start, end) if itemType == "" { return "", false, nil @@ -170,7 +170,7 @@ func processLines(fset *token.FileSet, items []ast.Node, comments []*ast.Comment } // findSplitJoinTarget returns the first curly bracket/parens that encloses the current cursor. 
-func findSplitJoinTarget(fset *token.FileSet, curFile cursor.Cursor, src []byte, start, end token.Pos) (itemType string, items []ast.Node, comments []*ast.CommentGroup, indent string, open, close token.Pos) { +func findSplitJoinTarget(fset *token.FileSet, curFile inspector.Cursor, src []byte, start, end token.Pos) (itemType string, items []ast.Node, comments []*ast.CommentGroup, indent string, open, close token.Pos) { isCursorInside := func(nodePos, nodeEnd token.Pos) bool { return nodePos < start && end < nodeEnd } diff --git a/gopls/internal/golang/rename.go b/gopls/internal/golang/rename.go index f11a3356d74..85c3c517245 100644 --- a/gopls/internal/golang/rename.go +++ b/gopls/internal/golang/rename.go @@ -62,6 +62,7 @@ import ( "golang.org/x/mod/modfile" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/objectpath" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/gopls/internal/cache" @@ -74,7 +75,6 @@ import ( "golang.org/x/tools/gopls/internal/util/moreiters" "golang.org/x/tools/gopls/internal/util/safetoken" internalastutil "golang.org/x/tools/internal/astutil" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/typesinternal" @@ -176,7 +176,7 @@ func prepareRenamePackageName(ctx context.Context, snapshot *cache.Snapshot, pgf } // Check validity of the metadata for the file's containing package. - meta, err := NarrowestMetadataForFile(ctx, snapshot, pgf.URI) + meta, err := snapshot.NarrowestMetadataForFile(ctx, pgf.URI) if err != nil { return nil, err } @@ -473,7 +473,7 @@ func renameOrdinary(ctx context.Context, snapshot *cache.Snapshot, f file.Handle // computes the union across all variants.) 
var targets map[types.Object]ast.Node var pkg *cache.Package - var cur cursor.Cursor // of selected Ident or ImportSpec + var cur inspector.Cursor // of selected Ident or ImportSpec { mps, err := snapshot.MetadataForFile(ctx, f.URI()) if err != nil { @@ -946,7 +946,7 @@ func renamePackage(ctx context.Context, s *cache.Snapshot, f file.Handle, newNam // We need metadata for the relevant package and module paths. // These should be the same for all packages containing the file. - meta, err := NarrowestMetadataForFile(ctx, s, f.URI()) + meta, err := s.NarrowestMetadataForFile(ctx, f.URI()) if err != nil { return nil, err } diff --git a/gopls/internal/golang/rename_check.go b/gopls/internal/golang/rename_check.go index 6b89cabbe81..060a2f5e6c6 100644 --- a/gopls/internal/golang/rename_check.go +++ b/gopls/internal/golang/rename_check.go @@ -43,10 +43,10 @@ import ( "unicode" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/edge" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/typesinternal" "golang.org/x/tools/refactor/satisfy" @@ -346,8 +346,8 @@ func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ide (*ast.CompositeLit)(nil), } ok := true - var visit func(cur cursor.Cursor) (descend bool) - visit = func(cur cursor.Cursor) (descend bool) { + var visit func(cur inspector.Cursor) (descend bool) + visit = func(cur inspector.Cursor) (descend bool) { if !ok { return false // bail out } @@ -401,7 +401,7 @@ func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ide // enclosingBlock returns the innermost block logically enclosing the // AST node (an ast.Ident), specified as a Cursor. 
-func enclosingBlock(info *types.Info, curId cursor.Cursor) *types.Scope { +func enclosingBlock(info *types.Info, curId inspector.Cursor) *types.Scope { for cur := range curId.Enclosing() { n := cur.Node() // For some reason, go/types always associates a diff --git a/gopls/internal/golang/symbols.go b/gopls/internal/golang/symbols.go index 53fbb663800..c49a498ab18 100644 --- a/gopls/internal/golang/symbols.go +++ b/gopls/internal/golang/symbols.go @@ -94,7 +94,7 @@ func PackageSymbols(ctx context.Context, snapshot *cache.Snapshot, uri protocol. // golang/vscode-go#3681: do our best if the file is not in a package. // TODO(rfindley): revisit this in the future once there is more graceful // handling in VS Code. - if mp, err := NarrowestMetadataForFile(ctx, snapshot, uri); err == nil { + if mp, err := snapshot.NarrowestMetadataForFile(ctx, uri); err == nil { pkgFiles = mp.CompiledGoFiles } diff --git a/gopls/internal/protocol/protocol.go b/gopls/internal/protocol/protocol.go index f98d6371273..2d6d8173523 100644 --- a/gopls/internal/protocol/protocol.go +++ b/gopls/internal/protocol/protocol.go @@ -11,7 +11,6 @@ import ( "fmt" "io" - "golang.org/x/telemetry/crashmonitor" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/jsonrpc2" @@ -302,7 +301,7 @@ func recoverHandlerPanic(method string) { // Report panics in the handler goroutine, // unless we have enabled the monitor, // which reports all crashes. 
- if !crashmonitor.Supported() { + if !true { defer func() { if x := recover(); x != nil { bug.Reportf("panic in %s request", method) diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index a3345d33a1d..b16009ec0ce 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -736,7 +736,7 @@ func (c *commandHandler) RunTests(ctx context.Context, args command.RunTestsArgs func (c *commandHandler) runTests(ctx context.Context, snapshot *cache.Snapshot, work *progress.WorkDone, uri protocol.DocumentURI, tests, benchmarks []string) error { // TODO: fix the error reporting when this runs async. - meta, err := golang.NarrowestMetadataForFile(ctx, snapshot, uri) + meta, err := snapshot.NarrowestMetadataForFile(ctx, uri) if err != nil { return err } @@ -1021,7 +1021,7 @@ func (c *commandHandler) GCDetails(ctx context.Context, uri protocol.DocumentURI }, func(ctx context.Context, deps commandDeps) error { return c.modifyState(ctx, FromToggleCompilerOptDetails, func() (*cache.Snapshot, func(), error) { // Don't blindly use "dir := deps.fh.URI().Dir()"; validate. - meta, err := golang.NarrowestMetadataForFile(ctx, deps.snapshot, deps.fh.URI()) + meta, err := deps.snapshot.NarrowestMetadataForFile(ctx, deps.fh.URI()) if err != nil { return nil, nil, err } @@ -1082,7 +1082,7 @@ func (c *commandHandler) ListImports(ctx context.Context, args command.URIArg) ( }) } } - meta, err := golang.NarrowestMetadataForFile(ctx, deps.snapshot, args.URI) + meta, err := deps.snapshot.NarrowestMetadataForFile(ctx, args.URI) if err != nil { return err // e.g. 
cancelled } diff --git a/gopls/internal/server/diagnostics.go b/gopls/internal/server/diagnostics.go index dbffc58fd99..95046d98117 100644 --- a/gopls/internal/server/diagnostics.go +++ b/gopls/internal/server/diagnostics.go @@ -272,7 +272,7 @@ func (s *server) diagnoseChangedFiles(ctx context.Context, snapshot *cache.Snaps } // Find all packages that include this file and diagnose them in parallel. - meta, err := golang.NarrowestMetadataForFile(ctx, snapshot, uri) + meta, err := snapshot.NarrowestMetadataForFile(ctx, uri) if err != nil { if ctx.Err() != nil { return nil, ctx.Err() diff --git a/gopls/internal/server/link.go b/gopls/internal/server/link.go index 75c717dbe8e..52e8ca379c5 100644 --- a/gopls/internal/server/link.go +++ b/gopls/internal/server/link.go @@ -162,7 +162,7 @@ func goLinks(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]p // This requires the import map from the package metadata. Ignore errors. var depsByImpPath map[golang.ImportPath]golang.PackageID if strings.ToLower(snapshot.Options().LinkTarget) == "pkg.go.dev" { - if meta, err := golang.NarrowestMetadataForFile(ctx, snapshot, fh.URI()); err == nil { + if meta, err := snapshot.NarrowestMetadataForFile(ctx, fh.URI()); err == nil { depsByImpPath = meta.DepsByImpPath } } diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index c4d10de3e91..f54c3f4208d 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -22,7 +22,6 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/typesinternal" ) @@ -526,7 +525,7 @@ func CanImport(from, to string) bool { func DeleteStmt(fset *token.FileSet, astFile *ast.File, stmt ast.Stmt, report func(string, ...any)) []analysis.TextEdit { // TODO: pass in the cursor to a ast.Stmt. 
callers should provide the Cursor insp := inspector.New([]*ast.File{astFile}) - root := cursor.Root(insp) + root := insp.Root() cstmt, ok := root.FindNode(stmt) if !ok { report("%s not found in file", stmt.Pos()) diff --git a/internal/analysisinternal/analysis_test.go b/internal/analysisinternal/analysis_test.go index e3c760aff5a..6aaf0f6df06 100644 --- a/internal/analysisinternal/analysis_test.go +++ b/internal/analysisinternal/analysis_test.go @@ -12,7 +12,6 @@ import ( "testing" "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/cursor" ) func TestCanImport(t *testing.T) { @@ -215,8 +214,8 @@ func TestDeleteStmt(t *testing.T) { t.Fatalf("%s: %v", tt.name, err) } insp := inspector.New([]*ast.File{f}) - root := cursor.Root(insp) - var stmt cursor.Cursor + root := insp.Root() + var stmt inspector.Cursor cnt := 0 for cn := range root.Preorder() { // Preorder(ast.Stmt(nil)) doesn't work if _, ok := cn.Node().(ast.Stmt); !ok { diff --git a/internal/astutil/edge/edge.go b/internal/astutil/edge/edge.go index c3f7661cbad..5ec9f4a356c 100644 --- a/internal/astutil/edge/edge.go +++ b/internal/astutil/edge/edge.go @@ -13,7 +13,7 @@ type Kind = edge.Kind //go:fix inline const ( - Invalid Kind = edge.Invalid + Invalid edge.Kind = edge.Invalid // Kinds are sorted alphabetically. // Numbering is not stable. diff --git a/internal/gofix/findgofix/findgofix.go b/internal/gofix/findgofix/findgofix.go index 38ce079b923..ceb42f8ee55 100644 --- a/internal/gofix/findgofix/findgofix.go +++ b/internal/gofix/findgofix/findgofix.go @@ -19,8 +19,8 @@ import ( "go/types" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/inspector" internalastutil "golang.org/x/tools/internal/astutil" - "golang.org/x/tools/internal/astutil/cursor" ) // A Handler handles language entities with go:fix directives. 
@@ -33,7 +33,7 @@ type Handler interface { // Find finds functions and constants annotated with an appropriate "//go:fix" // comment (the syntax proposed by #32816), and calls handler methods for each one. // h may be nil. -func Find(pass *analysis.Pass, root cursor.Cursor, h Handler) { +func Find(pass *analysis.Pass, root inspector.Cursor, h Handler) { for cur := range root.Preorder((*ast.FuncDecl)(nil), (*ast.GenDecl)(nil)) { switch decl := cur.Node().(type) { case *ast.FuncDecl: diff --git a/internal/gofix/gofix.go b/internal/gofix/gofix.go index 904f17cf3d5..51b23c65849 100644 --- a/internal/gofix/gofix.go +++ b/internal/gofix/gofix.go @@ -17,11 +17,10 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/analysisinternal" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/gofix/findgofix" "golang.org/x/tools/internal/refactor/inline" @@ -47,7 +46,7 @@ var Analyzer = &analysis.Analyzer{ // analyzer holds the state for this analysis. type analyzer struct { pass *analysis.Pass - root cursor.Cursor + root inspector.Cursor // memoization of repeated calls for same file. fileContent map[string][]byte // memoization of fact imports (nil => no fact) @@ -59,7 +58,7 @@ type analyzer struct { func run(pass *analysis.Pass) (any, error) { a := &analyzer{ pass: pass, - root: cursor.Root(pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)), + root: pass.ResultOf[inspect.Analyzer].(*inspector.Inspector).Root(), fileContent: make(map[string][]byte), inlinableFuncs: make(map[*types.Func]*inline.Callee), inlinableConsts: make(map[*types.Const]*goFixInlineConstFact), @@ -144,7 +143,7 @@ func (a *analyzer) inline() { } // If call is a call to an inlinable func, suggest inlining its use at cur. 
-func (a *analyzer) inlineCall(call *ast.CallExpr, cur cursor.Cursor) { +func (a *analyzer) inlineCall(call *ast.CallExpr, cur inspector.Cursor) { if fn := typeutil.StaticCallee(a.pass.TypesInfo, call); fn != nil { // Inlinable? callee, ok := a.inlinableFuncs[fn] @@ -214,7 +213,7 @@ func (a *analyzer) inlineCall(call *ast.CallExpr, cur cursor.Cursor) { } // If tn is the TypeName of an inlinable alias, suggest inlining its use at cur. -func (a *analyzer) inlineAlias(tn *types.TypeName, curId cursor.Cursor) { +func (a *analyzer) inlineAlias(tn *types.TypeName, curId inspector.Cursor) { inalias, ok := a.inlinableAliases[tn] if !ok { var fact goFixInlineAliasFact @@ -396,7 +395,7 @@ func typenames(t types.Type) []*types.TypeName { } // If con is an inlinable constant, suggest inlining its use at cur. -func (a *analyzer) inlineConst(con *types.Const, cur cursor.Cursor) { +func (a *analyzer) inlineConst(con *types.Const, cur inspector.Cursor) { incon, ok := a.inlinableConsts[con] if !ok { var fact goFixInlineConstFact @@ -489,7 +488,7 @@ func (a *analyzer) readFile(node ast.Node) ([]byte, error) { } // currentFile returns the unique ast.File for a cursor. -func currentFile(c cursor.Cursor) *ast.File { +func currentFile(c inspector.Cursor) *ast.File { for cf := range c.Enclosing((*ast.File)(nil)) { return cf.Node().(*ast.File) } diff --git a/internal/typesinternal/typeindex/typeindex.go b/internal/typesinternal/typeindex/typeindex.go index 34087a98fbf..e03deef4409 100644 --- a/internal/typesinternal/typeindex/typeindex.go +++ b/internal/typesinternal/typeindex/typeindex.go @@ -4,7 +4,7 @@ // Package typeindex provides an [Index] of type information for a // package, allowing efficient lookup of, say, whether a given symbol -// is referenced and, if so, where from; or of the [cursor.Cursor] for +// is referenced and, if so, where from; or of the [inspector.Cursor] for // the declaration of a particular [types.Object] symbol. 
package typeindex @@ -14,10 +14,9 @@ import ( "go/types" "iter" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" - "golang.org/x/tools/internal/astutil/cursor" - "golang.org/x/tools/internal/astutil/edge" "golang.org/x/tools/internal/typesinternal" ) @@ -30,7 +29,7 @@ func New(inspect *inspector.Inspector, pkg *types.Package, info *types.Info) *In inspect: inspect, info: info, packages: make(map[string]*types.Package), - def: make(map[types.Object]cursor.Cursor), + def: make(map[types.Object]inspector.Cursor), uses: make(map[types.Object]*uses), } @@ -40,7 +39,7 @@ func New(inspect *inspector.Inspector, pkg *types.Package, info *types.Info) *In } } - for cur := range cursor.Root(inspect).Preorder((*ast.ImportSpec)(nil), (*ast.Ident)(nil)) { + for cur := range inspect.Root().Preorder((*ast.ImportSpec)(nil), (*ast.Ident)(nil)) { switch n := cur.Node().(type) { case *ast.ImportSpec: // Index direct imports, including blank ones. @@ -83,9 +82,9 @@ func New(inspect *inspector.Inspector, pkg *types.Package, info *types.Info) *In type Index struct { inspect *inspector.Inspector info *types.Info - packages map[string]*types.Package // packages of all symbols referenced from this package - def map[types.Object]cursor.Cursor // Cursor of *ast.Ident that defines the Object - uses map[types.Object]*uses // Cursors of *ast.Idents that use the Object + packages map[string]*types.Package // packages of all symbols referenced from this package + def map[types.Object]inspector.Cursor // Cursor of *ast.Ident that defines the Object + uses map[types.Object]*uses // Cursors of *ast.Idents that use the Object } // A uses holds the list of Cursors of Idents that use a given symbol. @@ -107,14 +106,14 @@ type uses struct { // Uses returns the sequence of Cursors of [*ast.Ident]s in this package // that refer to obj. If obj is nil, the sequence is empty. 
-func (ix *Index) Uses(obj types.Object) iter.Seq[cursor.Cursor] { - return func(yield func(cursor.Cursor) bool) { +func (ix *Index) Uses(obj types.Object) iter.Seq[inspector.Cursor] { + return func(yield func(inspector.Cursor) bool) { if uses := ix.uses[obj]; uses != nil { var last int32 for code := uses.code; len(code) > 0; { delta, n := binary.Uvarint(code) last += int32(delta) - if !yield(cursor.At(ix.inspect, last)) { + if !yield(ix.inspect.At(last)) { return } code = code[n:] @@ -140,7 +139,7 @@ func (ix *Index) Used(objs ...types.Object) bool { // Def returns the Cursor of the [*ast.Ident] in this package // that declares the specified object, if any. -func (ix *Index) Def(obj types.Object) (cursor.Cursor, bool) { +func (ix *Index) Def(obj types.Object) (inspector.Cursor, bool) { cur, ok := ix.def[obj] return cur, ok } @@ -176,8 +175,8 @@ func (ix *Index) Selection(path, typename, name string) types.Object { // Calls returns the sequence of cursors for *ast.CallExpr nodes that // call the specified callee, as defined by [typeutil.Callee]. // If callee is nil, the sequence is empty. 
-func (ix *Index) Calls(callee types.Object) iter.Seq[cursor.Cursor] { - return func(yield func(cursor.Cursor) bool) { +func (ix *Index) Calls(callee types.Object) iter.Seq[inspector.Cursor] { + return func(yield func(inspector.Cursor) bool) { for cur := range ix.Uses(callee) { ek, _ := cur.ParentEdge() diff --git a/internal/typesinternal/typeindex/typeindex_test.go b/internal/typesinternal/typeindex/typeindex_test.go index c8b08dc9d00..9bba7a48ffa 100644 --- a/internal/typesinternal/typeindex/typeindex_test.go +++ b/internal/typesinternal/typeindex/typeindex_test.go @@ -14,7 +14,6 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" - "golang.org/x/tools/internal/astutil/cursor" "golang.org/x/tools/internal/testenv" "golang.org/x/tools/internal/typesinternal/typeindex" ) @@ -134,7 +133,7 @@ func BenchmarkIndex(b *testing.B) { b.Run("cursor", func(b *testing.B) { for b.Loop() { countB = 0 - for curCall := range cursor.Root(inspect).Preorder((*ast.CallExpr)(nil)) { + for curCall := range inspect.Root().Preorder((*ast.CallExpr)(nil)) { call := curCall.Node().(*ast.CallExpr) if typeutil.Callee(pkg.TypesInfo, call) == target { countB++ From d2ad3e0486b84f781eaaad3ec8c45c14a6a70b86 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 12 May 2025 13:51:32 -0400 Subject: [PATCH 074/196] internal/astutil/cursor: delete shims for old Cursor Updates golang/go#70859 Change-Id: Iaaa4bbddc7f3a4651b6c341f9e4d6ce8739b328a Reviewed-on: https://go-review.googlesource.com/c/tools/+/672017 Auto-Submit: Alan Donovan Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- go/ast/inspector/cursor.go | 3 - go/ast/inspector/cursor_test.go | 2 +- go/ast/inspector/inspector.go | 3 +- go/ast/inspector/walk.go | 2 +- gopls/internal/golang/snapshot.go | 5 -- internal/astutil/cursor/cursor.go | 17 ---- internal/astutil/edge/edge.go | 127 ------------------------------ 7 files changed, 3 insertions(+), 156 
deletions(-) delete mode 100644 internal/astutil/cursor/cursor.go delete mode 100644 internal/astutil/edge/edge.go diff --git a/go/ast/inspector/cursor.go b/go/ast/inspector/cursor.go index bec9e4decac..31c8d2f2409 100644 --- a/go/ast/inspector/cursor.go +++ b/go/ast/inspector/cursor.go @@ -4,9 +4,6 @@ package inspector -// TODO(adonovan): -// - apply-all //go:fix inline - import ( "fmt" "go/ast" diff --git a/go/ast/inspector/cursor_test.go b/go/ast/inspector/cursor_test.go index 90067c67060..8cda063ca21 100644 --- a/go/ast/inspector/cursor_test.go +++ b/go/ast/inspector/cursor_test.go @@ -16,8 +16,8 @@ import ( "strings" "testing" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" - "golang.org/x/tools/internal/astutil/edge" ) func TestCursor_Preorder(t *testing.T) { diff --git a/go/ast/inspector/inspector.go b/go/ast/inspector/inspector.go index 656302e2494..bc44b2c8e7e 100644 --- a/go/ast/inspector/inspector.go +++ b/go/ast/inspector/inspector.go @@ -46,9 +46,8 @@ package inspector import ( "go/ast" - _ "unsafe" - "golang.org/x/tools/internal/astutil/edge" + "golang.org/x/tools/go/ast/edge" ) // An Inspector provides methods for inspecting diff --git a/go/ast/inspector/walk.go b/go/ast/inspector/walk.go index 5a42174a0a0..5f1c93c8a73 100644 --- a/go/ast/inspector/walk.go +++ b/go/ast/inspector/walk.go @@ -13,7 +13,7 @@ import ( "fmt" "go/ast" - "golang.org/x/tools/internal/astutil/edge" + "golang.org/x/tools/go/ast/edge" ) func walkList[N ast.Node](v *visitor, ek edge.Kind, list []N) { diff --git a/gopls/internal/golang/snapshot.go b/gopls/internal/golang/snapshot.go index 30199d45463..53b2b872e6c 100644 --- a/gopls/internal/golang/snapshot.go +++ b/gopls/internal/golang/snapshot.go @@ -14,11 +14,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" ) -//go:fix inline -func NarrowestMetadataForFile(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI) (*metadata.Package, error) { - return 
snapshot.NarrowestMetadataForFile(ctx, uri) -} - // NarrowestPackageForFile is a convenience function that selects the narrowest // non-ITV package to which this file belongs, type-checks it in the requested // mode (full or workspace), and returns it, along with the parse tree of that diff --git a/internal/astutil/cursor/cursor.go b/internal/astutil/cursor/cursor.go deleted file mode 100644 index e328c484a08..00000000000 --- a/internal/astutil/cursor/cursor.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package cursor is deprecated; use [inspector.Cursor]. -package cursor - -import "golang.org/x/tools/go/ast/inspector" - -//go:fix inline -type Cursor = inspector.Cursor - -//go:fix inline -func Root(in *inspector.Inspector) inspector.Cursor { return in.Root() } - -//go:fix inline -func At(in *inspector.Inspector, index int32) inspector.Cursor { return in.At(index) } diff --git a/internal/astutil/edge/edge.go b/internal/astutil/edge/edge.go deleted file mode 100644 index 5ec9f4a356c..00000000000 --- a/internal/astutil/edge/edge.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package edge defines identifiers for each field of an ast.Node -// struct type that refers to another Node. -package edge - -import "golang.org/x/tools/go/ast/edge" - -//go:fix inline -type Kind = edge.Kind - -//go:fix inline -const ( - Invalid edge.Kind = edge.Invalid - - // Kinds are sorted alphabetically. - // Numbering is not stable. 
- // Each is named Type_Field, where Type is the - // ast.Node struct type and Field is the name of the field - - ArrayType_Elt = edge.ArrayType_Elt - ArrayType_Len = edge.ArrayType_Len - AssignStmt_Lhs = edge.AssignStmt_Lhs - AssignStmt_Rhs = edge.AssignStmt_Rhs - BinaryExpr_X = edge.BinaryExpr_X - BinaryExpr_Y = edge.BinaryExpr_Y - BlockStmt_List = edge.BlockStmt_List - BranchStmt_Label = edge.BranchStmt_Label - CallExpr_Args = edge.CallExpr_Args - CallExpr_Fun = edge.CallExpr_Fun - CaseClause_Body = edge.CaseClause_Body - CaseClause_List = edge.CaseClause_List - ChanType_Value = edge.ChanType_Value - CommClause_Body = edge.CommClause_Body - CommClause_Comm = edge.CommClause_Comm - CommentGroup_List = edge.CommentGroup_List - CompositeLit_Elts = edge.CompositeLit_Elts - CompositeLit_Type = edge.CompositeLit_Type - DeclStmt_Decl = edge.DeclStmt_Decl - DeferStmt_Call = edge.DeferStmt_Call - Ellipsis_Elt = edge.Ellipsis_Elt - ExprStmt_X = edge.ExprStmt_X - FieldList_List = edge.FieldList_List - Field_Comment = edge.Field_Comment - Field_Doc = edge.Field_Doc - Field_Names = edge.Field_Names - Field_Tag = edge.Field_Tag - Field_Type = edge.Field_Type - File_Decls = edge.File_Decls - File_Doc = edge.File_Doc - File_Name = edge.File_Name - ForStmt_Body = edge.ForStmt_Body - ForStmt_Cond = edge.ForStmt_Cond - ForStmt_Init = edge.ForStmt_Init - ForStmt_Post = edge.ForStmt_Post - FuncDecl_Body = edge.FuncDecl_Body - FuncDecl_Doc = edge.FuncDecl_Doc - FuncDecl_Name = edge.FuncDecl_Name - FuncDecl_Recv = edge.FuncDecl_Recv - FuncDecl_Type = edge.FuncDecl_Type - FuncLit_Body = edge.FuncLit_Body - FuncLit_Type = edge.FuncLit_Type - FuncType_Params = edge.FuncType_Params - FuncType_Results = edge.FuncType_Results - FuncType_TypeParams = edge.FuncType_TypeParams - GenDecl_Doc = edge.GenDecl_Doc - GenDecl_Specs = edge.GenDecl_Specs - GoStmt_Call = edge.GoStmt_Call - IfStmt_Body = edge.IfStmt_Body - IfStmt_Cond = edge.IfStmt_Cond - IfStmt_Else = edge.IfStmt_Else - IfStmt_Init = 
edge.IfStmt_Init - ImportSpec_Comment = edge.ImportSpec_Comment - ImportSpec_Doc = edge.ImportSpec_Doc - ImportSpec_Name = edge.ImportSpec_Name - ImportSpec_Path = edge.ImportSpec_Path - IncDecStmt_X = edge.IncDecStmt_X - IndexExpr_Index = edge.IndexExpr_Index - IndexExpr_X = edge.IndexExpr_X - IndexListExpr_Indices = edge.IndexListExpr_Indices - IndexListExpr_X = edge.IndexListExpr_X - InterfaceType_Methods = edge.InterfaceType_Methods - KeyValueExpr_Key = edge.KeyValueExpr_Key - KeyValueExpr_Value = edge.KeyValueExpr_Value - LabeledStmt_Label = edge.LabeledStmt_Label - LabeledStmt_Stmt = edge.LabeledStmt_Stmt - MapType_Key = edge.MapType_Key - MapType_Value = edge.MapType_Value - ParenExpr_X = edge.ParenExpr_X - RangeStmt_Body = edge.RangeStmt_Body - RangeStmt_Key = edge.RangeStmt_Key - RangeStmt_Value = edge.RangeStmt_Value - RangeStmt_X = edge.RangeStmt_X - ReturnStmt_Results = edge.ReturnStmt_Results - SelectStmt_Body = edge.SelectStmt_Body - SelectorExpr_Sel = edge.SelectorExpr_Sel - SelectorExpr_X = edge.SelectorExpr_X - SendStmt_Chan = edge.SendStmt_Chan - SendStmt_Value = edge.SendStmt_Value - SliceExpr_High = edge.SliceExpr_High - SliceExpr_Low = edge.SliceExpr_Low - SliceExpr_Max = edge.SliceExpr_Max - SliceExpr_X = edge.SliceExpr_X - StarExpr_X = edge.StarExpr_X - StructType_Fields = edge.StructType_Fields - SwitchStmt_Body = edge.SwitchStmt_Body - SwitchStmt_Init = edge.SwitchStmt_Init - SwitchStmt_Tag = edge.SwitchStmt_Tag - TypeAssertExpr_Type = edge.TypeAssertExpr_Type - TypeAssertExpr_X = edge.TypeAssertExpr_X - TypeSpec_Comment = edge.TypeSpec_Comment - TypeSpec_Doc = edge.TypeSpec_Doc - TypeSpec_Name = edge.TypeSpec_Name - TypeSpec_Type = edge.TypeSpec_Type - TypeSpec_TypeParams = edge.TypeSpec_TypeParams - TypeSwitchStmt_Assign = edge.TypeSwitchStmt_Assign - TypeSwitchStmt_Body = edge.TypeSwitchStmt_Body - TypeSwitchStmt_Init = edge.TypeSwitchStmt_Init - UnaryExpr_X = edge.UnaryExpr_X - ValueSpec_Comment = edge.ValueSpec_Comment - ValueSpec_Doc 
= edge.ValueSpec_Doc - ValueSpec_Names = edge.ValueSpec_Names - ValueSpec_Type = edge.ValueSpec_Type - ValueSpec_Values = edge.ValueSpec_Values -) From 3e377036196f644e59e757af8a38ea6afa07677c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 11:15:21 -0400 Subject: [PATCH 075/196] internal/refactor/inline: report when a "binding decl" was inserted Some clients prefer not to apply fixes that introduce a binding declaration (var params = args) for parameters that could not be substituted, on stylistic grounds. Expose that information. It is not yet used. Change-Id: I0f0dbc306b6ffaa5282182d98b854b63e5cb469d Reviewed-on: https://go-review.googlesource.com/c/tools/+/672695 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- internal/refactor/inline/inline.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index 652ce8b28f2..a17078cb96c 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -57,6 +57,7 @@ type Options struct { type Result struct { Content []byte // formatted, transformed content of caller file Literalized bool // chosen strategy replaced callee() with func(){...}() + BindingDecl bool // transformation added "var params = args" declaration // TODO(adonovan): provide an API for clients that want structured // output: a list of import additions and deletions plus one or more @@ -379,6 +380,7 @@ func (st *state) inline() (*Result, error) { return &Result{ Content: newSrc, Literalized: literalized, + BindingDecl: res.bindingDecl, }, nil } @@ -587,6 +589,7 @@ type inlineCallResult struct { // unfortunately in order to preserve comments, it is important that inlining // replace as little syntax as possible. elideBraces bool + bindingDecl bool // transformation inserted "var params = args" declaration old, new ast.Node // e.g.
replace call expr by callee function body expression } @@ -1008,6 +1011,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { res.new = results[0] } else { // Reduces to: { var (bindings); expr } + res.bindingDecl = true res.old = stmt res.new = &ast.BlockStmt{ List: []ast.Stmt{ @@ -1033,6 +1037,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { res.new = discard } else { // Reduces to: { var (bindings); _, _ = exprs } + res.bindingDecl = true res.new = &ast.BlockStmt{ List: []ast.Stmt{ bindingDecl.stmt, @@ -1062,6 +1067,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { List: newStmts, } if needBindingDecl { + res.bindingDecl = true block.List = prepend(bindingDecl.stmt, block.List...) } @@ -1178,6 +1184,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { body := calleeDecl.Body clearPositions(body) if needBindingDecl { + res.bindingDecl = true body.List = prepend(bindingDecl.stmt, body.List...) } res.old = ret @@ -1271,6 +1278,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { if bindingDecl != nil && allResultsUnreferenced { funcLit.Type.Params.List = nil remainingArgs = nil + res.bindingDecl = true funcLit.Body.List = prepend(bindingDecl.stmt, funcLit.Body.List...) } From 3c52d1f5b267c4d31ebd92bb55cd1183019247e2 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 2025 17:53:48 +0000 Subject: [PATCH 076/196] internal/mcp: reinstate ClientSession Add a logical client session type back to the SDK, to bring it closer to the design. Also, remove 'ConnectionOptions', which is better implemented using middleware, and reimplement logging as a LoggingTransport middleware, which is also added to the design. 
Change-Id: I908b4f837d513ffc60ae5b6171a31519cffafebc Reviewed-on: https://go-review.googlesource.com/c/tools/+/672795 LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 109 ++++++++++++++++------------ internal/mcp/cmd_test.go | 11 +-- internal/mcp/design/design.md | 25 +++++-- internal/mcp/examples/hello/main.go | 4 +- internal/mcp/mcp_test.go | 55 +++++++------- internal/mcp/server.go | 74 +++++++++---------- internal/mcp/server_example_test.go | 13 ++-- internal/mcp/sse.go | 2 +- internal/mcp/sse_example_test.go | 9 ++- internal/mcp/sse_test.go | 19 ++--- internal/mcp/transport.go | 102 +++++++++++++------------- 11 files changed, 227 insertions(+), 196 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 9467faba6a0..97caf60be15 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -16,29 +16,26 @@ import ( ) // A Client is an MCP client, which may be connected to an MCP server -// using the [Client.Start] method. +// using the [Client.Connect] method. type Client struct { - name string - version string - transport Transport - opts ClientOptions - mu sync.Mutex - conn *jsonrpc2.Connection - roots *featureSet[*Root] - initializeResult *initializeResult + name string + version string + opts ClientOptions + mu sync.Mutex + roots *featureSet[*Root] + sessions []*ClientSession } // NewClient creates a new Client. // -// Use [Client.Start] to connect it to an MCP server. +// Use [Client.Connect] to connect it to an MCP server. // // If non-nil, the provided options configure the Client. 
-func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { +func NewClient(name, version string, opts *ClientOptions) *Client { c := &Client{ - name: name, - version: version, - transport: t, - roots: newFeatureSet(func(r *Root) string { return r.URI }), + name: name, + version: version, + roots: newFeatureSet(func(r *Root) string { return r.URI }), } if opts != nil { c.opts = *opts @@ -48,63 +45,79 @@ func NewClient(name, version string, t Transport, opts *ClientOptions) *Client { // ClientOptions configures the behavior of the client. type ClientOptions struct { - ConnectionOptions } -// bind implements the binder[*Client] interface, so that Clients can +// bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(conn *jsonrpc2.Connection) *Client { +func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession { + cs := &ClientSession{ + conn: conn, + client: c, + } c.mu.Lock() defer c.mu.Unlock() - c.conn = conn - return c + c.sessions = append(c.sessions, cs) + return cs } // disconnect implements the binder[*Client] interface, so that // Clients can be connected using [connect]. -func (c *Client) disconnect(*Client) { - // Do nothing. In particular, do not set conn to nil: it needs to exist so it can - // return an error. +func (c *Client) disconnect(cs *ClientSession) { + c.mu.Lock() + defer c.mu.Unlock() + c.sessions = slices.DeleteFunc(c.sessions, func(cs2 *ClientSession) bool { + return cs2 == cs + }) } -// Start begins an MCP session by connecting the MCP client over its transport. +// Connect begins an MCP session by connecting to a server over the given +// transport, and initializing the session. // // Typically, it is the responsibility of the client to close the connection // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. 
-func (c *Client) Start(ctx context.Context) (err error) { - defer func() { - if err != nil { - _ = c.Close() - } - }() - _, err = connect(ctx, c.transport, &c.opts.ConnectionOptions, c) +func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c) if err != nil { - return err + return nil, err } params := &initializeParams{ ClientInfo: &implementation{Name: c.name, Version: c.version}, } - if err := call(ctx, c.conn, "initialize", params, &c.initializeResult); err != nil { - return err + if err := call(ctx, cs.conn, "initialize", params, &cs.initializeResult); err != nil { + _ = cs.Close() + return nil, err } - if err := c.conn.Notify(ctx, "notifications/initialized", &initializedParams{}); err != nil { - return err + if err := cs.conn.Notify(ctx, "notifications/initialized", &initializedParams{}); err != nil { + _ = cs.Close() + return nil, err } - return nil + return cs, nil +} + +// A ClientSession is a logical connection with an MCP server. Its +// methods can be used to send requests or notifications to the server. Create +// a session by calling [Client.Connect]. +// +// Call [ClientSession.Close] to close the connection, or await client +// termination with [ServerSession.Wait]. +type ClientSession struct { + conn *jsonrpc2.Connection + client *Client + initializeResult *initializeResult } // Close performs a graceful close of the connection, preventing new requests // from being handled, and waiting for ongoing requests to return. Close then // terminates the connection. -func (c *Client) Close() error { +func (c *ClientSession) Close() error { return c.conn.Close() } // Wait waits for the connection to be closed by the server. // Generally, clients should be responsible for closing the connection. 
-func (c *Client) Wait() error { +func (c *ClientSession) Wait() error { return c.conn.Wait() } @@ -136,7 +149,7 @@ func (c *Client) listRoots(_ context.Context, _ *ListRootsParams) (*ListRootsRes }, nil } -func (c *Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { +func (c *ClientSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { // TODO: when we switch to ClientSessions, use a copy of the server's dispatch function, or // maybe just add another type parameter. // @@ -148,34 +161,34 @@ func (c *Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) return struct{}{}, nil case "roots/list": // ListRootsParams happens to be unused. - return c.listRoots(ctx, nil) + return c.client.listRoots(ctx, nil) } return nil, jsonrpc2.ErrNotHandled } // Ping makes an MCP "ping" request to the server. -func (c *Client) Ping(ctx context.Context, params *PingParams) error { +func (c *ClientSession) Ping(ctx context.Context, params *PingParams) error { return call(ctx, c.conn, "ping", params, nil) } // ListPrompts lists prompts that are currently available on the server. -func (c *Client) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { +func (c *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { return standardCall[ListPromptsResult](ctx, c.conn, "prompts/list", params) } // GetPrompt gets a prompt from the server. -func (c *Client) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { +func (c *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { return standardCall[GetPromptResult](ctx, c.conn, "prompts/get", params) } // ListTools lists tools that are currently available on the server. 
-func (c *Client) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { +func (c *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { return standardCall[ListToolsResult](ctx, c.conn, "tools/list", params) } // CallTool calls the tool with the given name and arguments. // Pass a [CallToolOptions] to provide additional request fields. -func (c *Client) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { +func (c *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { defer func() { if err != nil { err = fmt.Errorf("calling tool %q: %w", name, err) @@ -201,12 +214,12 @@ type CallToolOptions struct { } // ListResources lists the resources that are currently available on the server. -func (c *Client) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { +func (c *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { return standardCall[ListResourcesResult](ctx, c.conn, "resources/list", params) } // ReadResource ask the server to read a resource and return its contents. 
-func (c *Client) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { +func (c *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { return standardCall[ReadResourceResult](ctx, c.conn, "resources/read", params) } diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index cf899178ba1..202f8495136 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -32,7 +32,7 @@ func runServer() { server := mcp.NewServer("greeter", "v0.0.1", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) - if err := server.Run(ctx, mcp.NewStdIOTransport(), nil); err != nil { + if err := server.Run(ctx, mcp.NewStdIOTransport()); err != nil { log.Fatal(err) } } @@ -48,11 +48,12 @@ func TestCmdTransport(t *testing.T) { cmd := exec.Command(exe) cmd.Env = append(os.Environ(), runAsServer+"=true") - client := mcp.NewClient("client", "v0.0.1", mcp.NewCommandTransport(cmd), nil) - if err := client.Start(ctx); err != nil { + client := mcp.NewClient("client", "v0.0.1", nil) + session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + if err != nil { log.Fatal(err) } - got, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + got, err := session.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { log.Fatal(err) } @@ -62,7 +63,7 @@ func TestCmdTransport(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) } - if err := client.Close(); err != nil { + if err := session.Close(); err != nil { t.Fatalf("closing server: %v", err) } } diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index f3905f394ae..8ceb7466399 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -34,7 +34,6 @@ compatible with mcp-go, translating between them should be straightforward in most cases. 
(Later, we will provide a detailed translation guide.) - # Requirements These may be obvious, but it's worthwhile to define goals for an official MCP @@ -283,12 +282,24 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport { func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ``` -Finally, we also provide an in-memory transport, for scenarios such as testing, -where the MCP client and server are in the same process. +Finally, we also provide a couple of transport implementations for special scenarios. +An InMemoryTransport can be used when the client and server reside in the same +process. A LoggingTransport is a middleware layer that logs RPC logs to a desired +location, specified as an io.Writer. ```go +// An InMemoryTransport is a [Transport] that communicates over an in-memory +// network connection, using newline-delimited JSON. type InMemoryTransport struct { /* ... */ } -func NewInMemoryTransport() (*InMemoryTransport, *InMemoryTransport) + +// NewInMemoryTransports returns two InMemoryTransports that connect to each +// other. +func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) + +// A LoggingTransport is a [Transport] that delegates to another transport, +// writing RPC logs to an io.Writer. +type LoggingTransport struct { /* ... */ } +func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport ``` **Differences from mcp-go**: The Go team has a battle-tested JSON-RPC @@ -538,7 +549,6 @@ func (*ClientSession) Resources(context.Context, *ListResourceParams) iter.Seq2[ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesParams) iter.Seq2[ResourceTemplate, error] ``` - ### Middleware We provide a mechanism to add MCP-level middleware on the server side, which runs after the @@ -902,15 +912,18 @@ handler to a Go function using reflection to derive its arguments. 
We provide In our design, each resource and resource template is associated with a function that reads it, with this signature: + ```go type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) ``` + The arguments include the `ServerSession` so the handler can observe the client's roots. The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the resource. The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: + ```go type ServerResource struct { Resource Resource @@ -939,12 +952,14 @@ The `ReadResource` method finds a resource or resource template matching the arg its assocated handler. To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: + ```go // FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. // It protects against path traversal attacks. // It will not read any file that is not in the root set of the client session requesting the resource. 
func (*Server) FileResourceHandler(dir string) ResourceHandler ``` + Here is an example: ```go diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index 8fbd74c66ad..b39b460f8ea 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -50,8 +50,8 @@ func main() { }) http.ListenAndServe(*httpAddr, handler) } else { - opts := &mcp.ConnectionOptions{Logger: os.Stderr} - if err := server.Run(context.Background(), mcp.NewStdIOTransport(), opts); err != nil { + t := mcp.NewLoggingTransport(mcp.NewStdIOTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { fmt.Fprintf(os.Stderr, "Server failed: %v", err) } } diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 9f7d0fe12bb..4a0b0b4acfa 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -33,7 +33,7 @@ func sayHi(ctx context.Context, cc *ServerSession, v hiParams) ([]*Content, erro func TestEndToEnd(t *testing.T) { ctx := context.Background() - ct, st := NewInMemoryTransport() + ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) @@ -63,7 +63,7 @@ func TestEndToEnd(t *testing.T) { ) // Connect the server. - ss, err := s.Connect(ctx, st, nil) + ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) } @@ -81,19 +81,20 @@ func TestEndToEnd(t *testing.T) { clientWG.Done() }() - c := NewClient("testClient", "v1.0.0", ct, nil) + c := NewClient("testClient", "v1.0.0", nil) c.AddRoots(&Root{URI: "file:///root"}) // Connect the client. 
- if err := c.Start(ctx); err != nil { + cs, err := c.Connect(ctx, ct) + if err != nil { t.Fatal(err) } - if err := c.Ping(ctx, nil); err != nil { + if err := cs.Ping(ctx, nil); err != nil { t.Fatalf("ping failed: %v", err) } t.Run("prompts", func(t *testing.T) { - res, err := c.ListPrompts(ctx, nil) + res, err := cs.ListPrompts(ctx, nil) if err != nil { t.Errorf("prompts/list failed: %v", err) } @@ -109,7 +110,7 @@ func TestEndToEnd(t *testing.T) { t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) } - gotReview, err := c.GetPrompt(ctx, &GetPromptParams{Name: "code_review", Arguments: map[string]string{"Code": "1+1"}}) + gotReview, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "code_review", Arguments: map[string]string{"Code": "1+1"}}) if err != nil { t.Fatal(err) } @@ -124,13 +125,13 @@ func TestEndToEnd(t *testing.T) { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } - if _, err := c.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), failure.Error()) { + if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), failure.Error()) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) } }) t.Run("tools", func(t *testing.T) { - res, err := c.ListTools(ctx, nil) + res, err := cs.ListTools(ctx, nil) if err != nil { t.Errorf("tools/list failed: %v", err) } @@ -160,7 +161,7 @@ func TestEndToEnd(t *testing.T) { t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) } - gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + gotHi, err := cs.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { t.Fatal(err) } @@ -171,7 +172,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } - gotFail, err := c.CallTool(ctx, "fail", map[string]any{}, nil) + gotFail, err := cs.CallTool(ctx, "fail", map[string]any{}, nil) // 
Counter-intuitively, when a tool fails, we don't expect an RPC error for // call tool: instead, the failure is embedded in the result. if err != nil { @@ -212,7 +213,7 @@ func TestEndToEnd(t *testing.T) { &ServerResource{resource1, readHandler}, &ServerResource{resource2, readHandler}) - lrres, err := c.ListResources(ctx, nil) + lrres, err := cs.ListResources(ctx, nil) if err != nil { t.Fatal(err) } @@ -228,7 +229,7 @@ func TestEndToEnd(t *testing.T) { {"file:///nonexistent.txt", ""}, // TODO(jba): add resource template cases when we implement them } { - rres, err := c.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) + rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) if err != nil { var werr *jsonrpc2.WireError if errors.As(err, &werr) && werr.Code == codeResourceNotFound { @@ -267,7 +268,7 @@ func TestEndToEnd(t *testing.T) { }) // Disconnect. - c.Close() + cs.Close() clientWG.Wait() // After disconnecting, neither client nor server should have any @@ -282,26 +283,27 @@ func TestEndToEnd(t *testing.T) { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Client) { +func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { t.Helper() ctx := context.Background() - ct, st := NewInMemoryTransport() + ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) // The 'greet' tool says hi. s.AddTools(tools...) 
- ss, err := s.Connect(ctx, st, nil) + ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", ct, nil) - if err := c.Start(ctx); err != nil { + c := NewClient("testClient", "v1.0.0", nil) + cs, err := c.Connect(ctx, ct) + if err != nil { t.Fatal(err) } - return ss, c + return ss, cs } func TestServerClosing(t *testing.T) { @@ -329,28 +331,29 @@ func TestServerClosing(t *testing.T) { func TestBatching(t *testing.T) { ctx := context.Background() - ct, st := NewInMemoryTransport() + ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) - _, err := s.Connect(ctx, st, nil) + _, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", ct, nil) + c := NewClient("testClient", "v1.0.0", nil) // TODO: this test is broken, because increasing the batch size here causes // 'initialize' to block. Therefore, we can only test with a size of 1. const batchSize = 1 BatchSize(ct, batchSize) - if err := c.Start(ctx); err != nil { + cs, err := c.Connect(ctx, ct) + if err != nil { t.Fatal(err) } - defer c.Close() + defer cs.Close() errs := make(chan error, batchSize) for i := range batchSize { go func() { - _, err := c.ListTools(ctx, nil) + _, err := cs.ListTools(ctx, nil) errs <- err }() time.Sleep(2 * time.Millisecond) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 0b741cdfd56..1e397c6d8f3 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -30,7 +30,7 @@ type Server struct { prompts *featureSet[*ServerPrompt] tools *featureSet[*ServerTool] resources *featureSet[*ServerResource] - conns []*ServerSession + sessions []*ServerSession } // ServerOptions is used to configure behavior of the server. @@ -160,7 +160,7 @@ func (s *Server) RemoveResources(uris ...string) { // Sessions returns an iterator that yields the current set of server sessions. 
func (s *Server) Sessions() iter.Seq[*ServerSession] { s.mu.Lock() - clients := slices.Clone(s.conns) + clients := slices.Clone(s.sessions) s.mu.Unlock() return slices.Values(clients) } @@ -248,8 +248,8 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection. -func (s *Server) Run(ctx context.Context, t Transport, opts *ConnectionOptions) error { - ss, err := s.Connect(ctx, t, opts) +func (s *Server) Run(ctx context.Context, t Transport) error { + ss, err := s.Connect(ctx, t) if err != nil { return err } @@ -261,7 +261,7 @@ func (s *Server) Run(ctx context.Context, t Transport, opts *ConnectionOptions) func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { cc := &ServerSession{conn: conn, server: s} s.mu.Lock() - s.conns = append(s.conns, cc) + s.sessions = append(s.sessions, cc) s.mu.Unlock() return cc } @@ -271,7 +271,7 @@ func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { func (s *Server) disconnect(cc *ServerSession) { s.mu.Lock() defer s.mu.Unlock() - s.conns = slices.DeleteFunc(s.conns, func(cc2 *ServerSession) bool { + s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { return cc2 == cc }) } @@ -282,8 +282,8 @@ func (s *Server) disconnect(cc *ServerSession) { // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport, opts *ConnectionOptions) (*ServerSession, error) { - return connect(ctx, t, opts, s) +func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { + return connect(ctx, t, s) } // A ServerSession is a logical connection from a single MCP client. Its @@ -302,18 +302,18 @@ type ServerSession struct { } // Ping makes an MCP "ping" request to the client. 
-func (cc *ServerSession) Ping(ctx context.Context, _ *PingParams) error { - return call(ctx, cc.conn, "ping", nil, nil) +func (ss *ServerSession) Ping(ctx context.Context, _ *PingParams) error { + return call(ctx, ss.conn, "ping", nil, nil) } -func (cc *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { - return standardCall[ListRootsResult](ctx, cc.conn, "roots/list", params) +func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + return standardCall[ListRootsResult](ctx, ss.conn, "roots/list", params) } -func (cc *ServerSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { - cc.mu.Lock() - initialized := cc.initialized - cc.mu.Unlock() +func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { + ss.mu.Lock() + initialized := ss.initialized + ss.mu.Unlock() // From the spec: // "The client SHOULD NOT send requests other than pings before the server @@ -332,39 +332,39 @@ func (cc *ServerSession) handle(ctx context.Context, req *jsonrpc2.Request) (any switch req.Method { case "initialize": - return dispatch(ctx, cc, req, cc.initialize) + return dispatch(ctx, ss, req, ss.initialize) case "ping": // The spec says that 'ping' expects an empty object result. 
return struct{}{}, nil case "prompts/list": - return dispatch(ctx, cc, req, cc.server.listPrompts) + return dispatch(ctx, ss, req, ss.server.listPrompts) case "prompts/get": - return dispatch(ctx, cc, req, cc.server.getPrompt) + return dispatch(ctx, ss, req, ss.server.getPrompt) case "tools/list": - return dispatch(ctx, cc, req, cc.server.listTools) + return dispatch(ctx, ss, req, ss.server.listTools) case "tools/call": - return dispatch(ctx, cc, req, cc.server.callTool) + return dispatch(ctx, ss, req, ss.server.callTool) case "resources/list": - return dispatch(ctx, cc, req, cc.server.listResources) + return dispatch(ctx, ss, req, ss.server.listResources) case "resources/read": - return dispatch(ctx, cc, req, cc.server.readResource) + return dispatch(ctx, ss, req, ss.server.readResource) case "notifications/initialized": } return nil, jsonrpc2.ErrNotHandled } -func (cc *ServerSession) initialize(ctx context.Context, _ *ServerSession, params *initializeParams) (*initializeResult, error) { - cc.mu.Lock() - cc.initializeParams = params - cc.mu.Unlock() +func (ss *ServerSession) initialize(ctx context.Context, _ *ServerSession, params *initializeParams) (*initializeResult, error) { + ss.mu.Lock() + ss.initializeParams = params + ss.mu.Unlock() // Mark the connection as initialized when this method exits. TODO: // Technically, the server should not be considered initialized until it has @@ -372,9 +372,9 @@ func (cc *ServerSession) initialize(ctx context.Context, _ *ServerSession, param // connection to implement that easily. In any case, once we've initialized // here, we can handle requests. 
defer func() { - cc.mu.Lock() - cc.initialized = true - cc.mu.Unlock() + ss.mu.Lock() + ss.initialized = true + ss.mu.Unlock() }() return &initializeResult{ @@ -388,10 +388,10 @@ func (cc *ServerSession) initialize(ctx context.Context, _ *ServerSession, param ListChanged: false, // not yet supported }, }, - Instructions: cc.server.opts.Instructions, + Instructions: ss.server.opts.Instructions, ServerInfo: &implementation{ - Name: cc.server.name, - Version: cc.server.version, + Name: ss.server.name, + Version: ss.server.version, }, }, nil } @@ -399,13 +399,13 @@ func (cc *ServerSession) initialize(ctx context.Context, _ *ServerSession, param // Close performs a graceful shutdown of the connection, preventing new // requests from being handled, and waiting for ongoing requests to return. // Close then terminates the connection. -func (cc *ServerSession) Close() error { - return cc.conn.Close() +func (ss *ServerSession) Close() error { + return ss.conn.Close() } // Wait waits for the connection to be closed by the client. -func (cc *ServerSession) Wait() error { - return cc.conn.Wait() +func (ss *ServerSession) Wait() error { + return ss.conn.Wait() } // dispatch turns a strongly type request handler into a jsonrpc2 handler. 
diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index ed7438184e2..4a9a9c7044c 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -24,28 +24,29 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]* func ExampleServer() { ctx := context.Background() - clientTransport, serverTransport := mcp.NewInMemoryTransport() + clientTransport, serverTransport := mcp.NewInMemoryTransports() server := mcp.NewServer("greeter", "v0.0.1", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) - serverSession, err := server.Connect(ctx, serverTransport, nil) + serverSession, err := server.Connect(ctx, serverTransport) if err != nil { log.Fatal(err) } - client := mcp.NewClient("client", "v0.0.1", clientTransport, nil) - if err := client.Start(ctx); err != nil { + client := mcp.NewClient("client", "v0.0.1", nil) + clientSession, err := client.Connect(ctx, clientTransport) + if err != nil { log.Fatal(err) } - res, err := client.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + res, err := clientSession.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { log.Fatal(err) } fmt.Println(res.Content[0].Text) - client.Close() + clientSession.Close() serverSession.Wait() // Output: Hi user diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index f1f657f94bb..bd82538769a 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -229,7 +229,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // TODO(hxjiang): getServer returns nil will panic. 
server := h.getServer(req) - ss, err := server.Connect(req.Context(), transport, nil) + ss, err := server.Connect(req.Context(), transport) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index 391746f58fd..ef4269d46ff 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -34,13 +34,14 @@ func ExampleSSEHandler() { ctx := context.Background() transport := mcp.NewSSEClientTransport(httpServer.URL) - client := mcp.NewClient("test", "v1.0.0", transport, nil) - if err := client.Start(ctx); err != nil { + client := mcp.NewClient("test", "v1.0.0", nil) + cs, err := client.Connect(ctx, transport) + if err != nil { log.Fatal(err) } - defer client.Close() + defer cs.Close() - res, err := client.CallTool(ctx, "add", map[string]any{"x": 1, "y": 2}, nil) + res, err := cs.CallTool(ctx, "add", map[string]any{"x": 1, "y": 2}, nil) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index 661cb5436f8..cba0ada9235 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -35,15 +35,16 @@ func TestSSEServer(t *testing.T) { clientTransport := NewSSEClientTransport(httpServer.URL) - c := NewClient("testClient", "v1.0.0", clientTransport, nil) - if err := c.Start(ctx); err != nil { + c := NewClient("testClient", "v1.0.0", nil) + cs, err := c.Connect(ctx, clientTransport) + if err != nil { t.Fatal(err) } - if err := c.Ping(ctx, nil); err != nil { + if err := cs.Ping(ctx, nil); err != nil { t.Fatal(err) } - cc := <-conns - gotHi, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + ss := <-conns + gotHi, err := cs.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) if err != nil { t.Fatal(err) } @@ -57,11 +58,11 @@ func TestSSEServer(t *testing.T) { // Test that closing either end of the connection terminates the other // end. 
if closeServerFirst { - c.Close() - cc.Wait() + cs.Close() + ss.Wait() } else { - cc.Close() - c.Wait() + ss.Close() + cs.Wait() } }) } diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index a39d74d5b16..0fbe7082a80 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -41,13 +41,6 @@ type Stream interface { io.Closer } -// ConnectionOptions configures the behavior of an individual client<->server -// connection. -type ConnectionOptions struct { - SessionID string // if set, the session ID - Logger io.Writer // if set, write RPC logs -} - // A StdIOTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. type StdIOTransport struct { @@ -76,9 +69,9 @@ type InMemoryTransport struct { ioTransport } -// NewInMemoryTransport returns two InMemoryTransports that connect to each +// NewInMemoryTransports returns two InMemoryTransports that connect to each // other. -func NewInMemoryTransport() (*InMemoryTransport, *InMemoryTransport) { +func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { c1, c2 := net.Pipe() return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} } @@ -93,11 +86,7 @@ type binder[T handler] interface { disconnect(T) } -func connect[H handler](ctx context.Context, t Transport, opts *ConnectionOptions, b binder[H]) (H, error) { - if opts == nil { - opts = new(ConnectionOptions) - } - +func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) { var zero H stream, err := t.Connect(ctx) if err != nil { @@ -105,11 +94,6 @@ func connect[H handler](ctx context.Context, t Transport, opts *ConnectionOption } // If logging is configured, write message logs. 
reader, writer := jsonrpc2.Reader(stream), jsonrpc2.Writer(stream) - if opts.Logger != nil { - reader = loggingReader(opts.Logger, reader) - writer = loggingWriter(opts.Logger, writer) - } - var ( h H preempter canceller @@ -178,54 +162,66 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, return nil } -// The helpers below are used to bind transports to jsonrpc2. - -// A readerFunc implements jsonrpc2.Reader.Read. -type readerFunc func(context.Context) (jsonrpc2.Message, int64, error) +// A LoggingTransport is a [Transport] that delegates to another transport, +// writing RPC logs to an io.Writer. +type LoggingTransport struct { + delegate Transport + w io.Writer +} -func (f readerFunc) Read(ctx context.Context) (jsonrpc2.Message, int64, error) { - return f(ctx) +// NewLoggingTransport creates a new LoggingTransport that delegates to the +// provided transport, writing RPC logs to the provided io.Writer. +func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport { + return &LoggingTransport{delegate, w} } -// A writerFunc implements jsonrpc2.Writer.Write. -type writerFunc func(context.Context, jsonrpc2.Message) (int64, error) +// Connect connects the underlying transport, returning a [Stream] that writes +// logs to the configured destination. +func (t *LoggingTransport) Connect(ctx context.Context) (Stream, error) { + delegate, err := t.delegate.Connect(ctx) + if err != nil { + return nil, err + } + return &loggingStream{delegate, t.w}, nil +} -func (f writerFunc) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) { - return f(ctx, msg) +type loggingStream struct { + delegate Stream + w io.Writer } // loggingReader is a stream middleware that logs incoming messages. 
-func loggingReader(w io.Writer, delegate jsonrpc2.Reader) jsonrpc2.Reader { - return readerFunc(func(ctx context.Context) (jsonrpc2.Message, int64, error) { - msg, n, err := delegate.Read(ctx) +func (s *loggingStream) Read(ctx context.Context) (jsonrpc2.Message, int64, error) { + msg, n, err := s.delegate.Read(ctx) + if err != nil { + fmt.Fprintf(s.w, "read error: %v", err) + } else { + data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - fmt.Fprintf(w, "read error: %v", err) - } else { - data, err := jsonrpc2.EncodeMessage(msg) - if err != nil { - fmt.Fprintf(w, "LoggingFramer: failed to marshal: %v", err) - } - fmt.Fprintf(w, "read: %s", string(data)) + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } - return msg, n, err - }) + fmt.Fprintf(s.w, "read: %s", string(data)) + } + return msg, n, err } // loggingWriter is a stream middleware that logs outgoing messages. -func loggingWriter(w io.Writer, delegate jsonrpc2.Writer) jsonrpc2.Writer { - return writerFunc(func(ctx context.Context, msg jsonrpc2.Message) (int64, error) { - n, err := delegate.Write(ctx, msg) +func (s *loggingStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) { + n, err := s.delegate.Write(ctx, msg) + if err != nil { + fmt.Fprintf(s.w, "write error: %v", err) + } else { + data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - fmt.Fprintf(w, "write error: %v", err) - } else { - data, err := jsonrpc2.EncodeMessage(msg) - if err != nil { - fmt.Fprintf(w, "LoggingFramer: failed to marshal: %v", err) - } - fmt.Fprintf(w, "write: %s", string(data)) + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } - return n, err - }) + fmt.Fprintf(s.w, "write: %s", string(data)) + } + return n, err +} + +func (s *loggingStream) Close() error { + return s.delegate.Close() } // A rwc binds an io.ReadCloser and io.WriteCloser together to create an From 04dca596aa4130810edc8c7033ea8ab86484040f Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 
2025 17:58:58 +0000 Subject: [PATCH 077/196] internal/mcp: hide CancelledParams Since cancellation is handled transparently, this type need not be exported. Change-Id: Id4e195b0c3bd49eca52587032a027ffa449eb9d1 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672796 LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla --- internal/mcp/generate.go | 2 +- internal/mcp/protocol.go | 22 +++++++++++----------- internal/mcp/transport.go | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index bfeca7705b0..dc136714988 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -66,7 +66,7 @@ var declarations = config{ "CallToolResult": {}, "CancelledNotification": { Name: "-", - Fields: config{"Params": {Name: "CancelledParams"}}, + Fields: config{"Params": {Name: "cancelledParams"}}, }, "ClientCapabilities": {}, "GetPromptRequest": { diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index dcccceff509..27e725e4b05 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -54,17 +54,6 @@ type CallToolResult struct { IsError bool `json:"isError,omitempty"` } -type CancelledParams struct { - // An optional string describing the reason for the cancellation. This MAY be - // logged or presented to the user. - Reason string `json:"reason,omitempty"` - // The ID of the request to cancel. - // - // This MUST correspond to the ID of a request previously issued in the same - // direction. - RequestID any `json:"requestId"` -} - // Capabilities a client may support. Known capabilities are defined here, in // this schema, but this is not a closed set: any client can define its own, // additional capabilities. @@ -315,6 +304,17 @@ type ToolAnnotations struct { Title string `json:"title,omitempty"` } +type cancelledParams struct { + // An optional string describing the reason for the cancellation. This MAY be + // logged or presented to the user. 
+ Reason string `json:"reason,omitempty"` + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued in the same + // direction. + RequestID any `json:"requestId"` +} + // Describes the name and version of an MCP implementation. type implementation struct { Name string `json:"name"` diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 0fbe7082a80..6c4a7319655 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -126,7 +126,7 @@ type canceller struct { // Preempt implements jsonrpc2.Preempter. func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result any, err error) { if req.Method == "notifications/cancelled" { - var params CancelledParams + var params cancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err } @@ -151,7 +151,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) case ctx.Err() != nil: // Notify the peer of cancellation. - err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &CancelledParams{ + err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &cancelledParams{ Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), }) From 0c0d3300bcff0c90129d404fee43751899e9db38 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 15:05:29 -0400 Subject: [PATCH 078/196] internal/diff/lcs: log pseudorandom seeds so failures can be repro'd This change logs the seeds used to initialize the PRNG used by various tests so that failures on the builders can be reproduced. Running with -count=1000 readily discovers several failures; the seeds have been logged. I have not yet investigated why the test is failing. 
Updates golang/go#73714 Change-Id: I0be1a2a1470a890892fb377a36b442385c20e323 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672617 LUCI-TryBot-Result: Go LUCI Reviewed-by: Peter Weinberger --- internal/diff/lcs/common_test.go | 6 +++--- internal/diff/lcs/old_test.go | 34 +++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/internal/diff/lcs/common_test.go b/internal/diff/lcs/common_test.go index 68f4485fdb8..1a621f3f76a 100644 --- a/internal/diff/lcs/common_test.go +++ b/internal/diff/lcs/common_test.go @@ -6,7 +6,7 @@ package lcs import ( "log" - "math/rand" + "math/rand/v2" "slices" "strings" "testing" @@ -105,11 +105,11 @@ func lcslen(l lcs) int { } // return a random string of length n made of characters from s -func randstr(s string, n int) string { +func randstr(rng *rand.Rand, s string, n int) string { src := []rune(s) x := make([]rune, n) for i := range n { - x[i] = src[rand.Intn(len(src))] + x[i] = src[rng.Int64N(int64(len(src)))] } return string(x) } diff --git a/internal/diff/lcs/old_test.go b/internal/diff/lcs/old_test.go index 2eac1af6d2f..32639e0c69f 100644 --- a/internal/diff/lcs/old_test.go +++ b/internal/diff/lcs/old_test.go @@ -7,7 +7,7 @@ package lcs import ( "fmt" "log" - "math/rand" + "math/rand/v2" "os" "strings" "testing" @@ -106,12 +106,21 @@ func TestRegressionOld003(t *testing.T) { } func TestRandOld(t *testing.T) { - rand.Seed(1) + // This test has been observed to fail on 0.5% of runs, + // for example, using these seed pairs: + // - 14495503613572398601, 14235960032715252551 + // - 4604737379815557952, 3544687276571693387 + // - 5419883078200329767, 11421218423438832472 + // - 16595049989808072974, 2139246309634979125 + // - 7079260183459082455, 16563974573191788291 + // TODO(adonovan): fix. + rng := rng(t) + for i := range 1000 { // TODO(adonovan): use ASCII and bytesSeqs here? The use of // non-ASCII isn't relevant to the property exercised by the test. 
- a := []rune(randstr("abω", 16)) - b := []rune(randstr("abωc", 16)) + a := []rune(randstr(rng, "abω", 16)) + b := []rune(randstr(rng, "abωc", 16)) seq := runesSeqs{a, b} const lim = 24 // large enough to get true lcs @@ -158,7 +167,7 @@ func TestDiffAPI(t *testing.T) { } func BenchmarkTwoOld(b *testing.B) { - tests := genBench("abc", 96) + tests := genBench(rng(b), "abc", 96) for i := 0; i < b.N; i++ { for _, tt := range tests { _, two := compute(stringSeqs{tt.before, tt.after}, twosided, 100) @@ -170,7 +179,7 @@ func BenchmarkTwoOld(b *testing.B) { } func BenchmarkForwOld(b *testing.B) { - tests := genBench("abc", 96) + tests := genBench(rng(b), "abc", 96) for i := 0; i < b.N; i++ { for _, tt := range tests { _, two := compute(stringSeqs{tt.before, tt.after}, forward, 100) @@ -181,14 +190,21 @@ func BenchmarkForwOld(b *testing.B) { } } -func genBench(set string, n int) []struct{ before, after string } { +// rng returns a randomly initialized PRNG whose seeds are logged so +// that occasional test failures can be deterministically replayed. +func rng(tb testing.TB) *rand.Rand { + seed1, seed2 := rand.Uint64(), rand.Uint64() + tb.Logf("PRNG seeds: %d, %d", seed1, seed2) + return rand.New(rand.NewPCG(seed1, seed2)) +} + +func genBench(rng *rand.Rand, set string, n int) []struct{ before, after string } { // before and after for benchmarks. 24 strings of length n with // before and after differing at least once, and about 5% - rand.Seed(3) var ans []struct{ before, after string } for range 24 { // maybe b should have an approximately known number of diffs - a := randstr(set, n) + a := randstr(rng, set, n) cnt := 0 bb := make([]rune, 0, n) for _, r := range a { From 279ce35e2ea3600c9a8444a0fa3d07ed4cfef5df Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 2025 18:28:24 +0000 Subject: [PATCH 079/196] internal/mcp/design: further clean-up Do a pass through the design doc, adding a bit more clean-up and commentary. 
Additionally, update ServerTool and ServerPrompt to be consistent with the design. Change-Id: I1b75c4690d86fe8946d512cb36a1fc2938b3c883 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672576 Reviewed-by: Jonathan Amsterdam Reviewed-by: Sam Thanawalla Auto-Submit: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 2 +- internal/mcp/design/design.md | 237 ++++++++++++++++++++-------------- internal/mcp/mcp.go | 8 +- internal/mcp/prompt.go | 18 +-- internal/mcp/prompt_test.go | 4 +- internal/mcp/server.go | 8 +- internal/mcp/tool.go | 8 +- internal/mcp/tool_test.go | 4 +- 8 files changed, 168 insertions(+), 121 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 97caf60be15..dff7c1bdb43 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -208,7 +208,7 @@ func (c *ClientSession) CallTool(ctx context.Context, name string, args map[stri // NOTE: the following struct should consist of all fields of callToolParams except name and arguments. -// CallToolOptions contains options to [Client.CallTools]. +// CallToolOptions contains options to [ClientSession.CallTool]. type CallToolOptions struct { ProgressToken any // string or int } diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 8ceb7466399..fbd59560306 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -4,10 +4,12 @@ This document discusses the design of a Go SDK for the [model context protocol](https://modelcontextprotocol.io/specification/2025-03-26). It is intended to seed a GitHub discussion about the official Go MCP SDK. -The golang.org/x/tools/internal/mcp package contains a prototype that we built -to explore the MCP design space. Many of the ideas there are present in this -document. However, we have diverged and expanded on the APIs of that prototype, -and this document should be considered canonical. 
+The +[golang.org/x/tools/internal/mcp](https://pkg.go.dev/golang.org/x/tools/internal/mcp@master) +package contains a prototype that we built to explore the MCP design space. +Many of the ideas there are present in this document. However, we have diverged +from and expanded on the APIs of that prototype, and this document should be +considered canonical. ## Similarities and differences with mark3labs/mcp-go @@ -24,9 +26,7 @@ not viable for an official SDK. In yet others, we simply think there is room for API refinement, and we should take this opportunity to consider our options. Therefore, we wrote this document as though it were proposing a new implementation. Nevertheless, much of the API discussed here originated from or -was inspired by mcp-go and other unofficial SDKs, and if the consensus of this -discussion is close enough to mcp-go or any other unofficial SDK, we can start -from a fork. +was inspired by mcp-go and other unofficial SDKs. Since mcp-go is so influential and popular, we have noted significant differences from its API in the sections below. Although the API here is not @@ -58,8 +58,7 @@ SDK. An official SDK should aim to be: In the sections below, we visit each aspect of the MCP spec, in approximately the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26) -For each, we discuss considerations for the Go implementation. In many cases an -API is suggested, though in some there may be open questions. +For each, we discuss considerations for the Go implementation, and propose a Go API. ## Foundations @@ -68,22 +67,26 @@ API is suggested, though in some there may be open questions. In the sections that follow, it is assumed that most of the MCP API lives in a single shared package, the `mcp` package. This is inconsistent with other MCP SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or -`google.golang.org/grpc`. +`google.golang.org/grpc`. 
We believe that having a single package aids +discoverability in package documentation and in the IDE. Furthermore, it avoids +somewhat arbitrary decisions about package structure that may be rendered +inaccurate by future evolution of the spec. Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) belongs in a separate package. -Therefore, this is the package layout. `module.path` is a placeholder for the -final module path of the mcp module +Therefore, this is the package layout, assuming +github.com/modelcontextprotocol/go-sdk as the module path. -- `module.path/mcp`: the bulk of the user facing API -- `module.path/jsonschema`: a jsonschema implementation, with validation -- `module.path/internal/jsonrpc2`: a fork of x/tools/internal/jsonrpc2_v2 +- `github.com/modelcontextprotocol/go-sdk/mcp`: the bulk of the user facing API +- `github.com/modelcontextprotocol/go-sdk/jsonschema`: a jsonschema implementation, with validation +- `github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2`: a fork of x/tools/internal/jsonrpc2_v2 The JSON-RPC implementation is hidden, to avoid tight coupling. As described in the next section, the only aspects of JSON-RPC that need to be exposed in the SDK are the message types, for the purposes of defining custom transports. We -can expose these types from the `mcp` package via aliases or wrappers. +can expose these types by promoting them from the `mcp` package using aliases +or wrappers. **Difference from mcp-go**: Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, `server` and `transport` packages. @@ -98,12 +101,15 @@ defines two transports: - **streamable http**: communication over a relatively complicated series of text/event-stream GET and HTTP POST requests.
-Additionally, version `2024-11-05` of the spec defined a simpler HTTP transport: +Additionally, version `2024-11-05` of the spec defined a simpler (yet stateful) +HTTP transport: - **sse**: client issues a hanging GET request and receives messages via `text/event-stream`, and sends messages via POST to a session endpoint. -Furthermore, the spec [states](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#custom-transports) that it must be possible for users to define their own custom transports. +Furthermore, the spec +[states](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#custom-transports) +that it must be possible for users to define their own custom transports. Given the diversity of the transport implementations, they can be challenging to abstract. However, since JSON-RPC requires a bidirectional stream, we can @@ -124,13 +130,12 @@ type Stream interface { } ``` -Methods accept a Go `Context` and return an `error`, -as is idiomatic for APIs that do I/O. +Methods accept a Go `Context` and return an `error`, as is idiomatic for APIs +that do I/O. -A `Transport` is something that connects a logical JSON-RPC -stream, and nothing more. Streams must be closeable in order to -implement client and server shutdown, and therefore conform to the `io.Closer` -interface. +A `Transport` is something that connects a logical JSON-RPC stream, and nothing +more. Streams must be closeable in order to implement client and server +shutdown, and therefore conform to the `io.Closer` interface. Other SDKs define higher-level transports, with, for example, methods to send a notification or make a call. Those are jsonrpc2 operations on top of the @@ -142,6 +147,16 @@ language server `gopls`, which we propose to fork for the MCP SDK. It already handles concerns like client/server connection, request lifecycle, cancellation, and shutdown. 
+**Differences from mcp-go**: The Go team has a battle-tested JSON-RPC +implementation that we use for gopls, our Go LSP server. We are using the new +version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 +features, including cancellation. + +The `Transport` interface here is lower-level than that of mcp-go, but serves a +similar purpose. We believe the lower-level interface is easier to implement. + +#### stdio transports + In the MCP Spec, the **stdio** transport uses newline-delimited JSON to communicate over stdin/stdout. It's possible to model both client side and server side of this communication with a shared type that communicates over an @@ -178,13 +193,20 @@ func NewStdIOTransport() *StdIOTransport func (t *StdIOTransport) Connect(context.Context) (Stream, error) ``` +#### HTTP transports + The HTTP transport APIs are even more asymmetrical. Since connections are initiated via HTTP requests, the client developer will create a transport, but the server developer will typically install an HTTP handler. Internally, the -HTTP handler will create a transport for each new client connection. +HTTP handler will create a logical transport for each new client connection. Importantly, since they serve many connections, the HTTP handlers must accept a -callback to get an MCP server for each new session. +callback to get an MCP server for each new session. As described below, MCP +servers can optionally connect to multiple clients. This allows customization +of per-session servers: if the MCP server is stateless, the user can return the +same MCP server for each connection. On the other hand, if any per-session +customization is required, it is possible by returning a different `Server` +instance for each connection. ```go // SSEHTTPHandler is an http.Handler that serves SSE-based MCP sessions as defined by @@ -212,7 +234,7 @@ see [Middleware](#Middleware) below. 
By default, the SSE handler creates messages endpoints with the `?sessionId=...` query parameter. Users that want more control over the management of sessions and session endpoints may write their own handler, and -create `SSEServerTransport` instances themselves, for incoming GET requests. +create `SSEServerTransport` instances themselves for incoming GET requests. ```go // A SSEServerTransport is a logical SSE session created through a hanging GET @@ -233,6 +255,9 @@ type SSEServerTransport struct { /* ... */ } // The transport is itself an [http.Handler]. It is the caller's responsibility // to ensure that the resulting transport serves HTTP requests on the given // session endpoint. +// +// Most callers should instead use an [SSEHTTPHandler], which transparently handles +// the delegation to SSEServerTransports. func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport // ServeHTTP handles POST requests to the transport endpoint. @@ -282,7 +307,27 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport { func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ``` -Finally, we also provide a couple of transport implementations for special scenarios. +**Differences from mcp-go**: In mcp-go, server authors create an `MCPServer`, +populate it with tools, resources and so on, and then wrap it in an `SSEServer` +or `StdioServer`. Users can manage their own sessions with `RegisterSession` +and `UnregisterSession`. Rather than use a server constructor to get a distinct +server for each connection, there is a concept of a "session tool" that +overlays tools for a specific session. + +We find the similarity in names among the three server types to be confusing, +and we could not discover any uses of the session methods in the open-source +ecosystem.
Furthermore, we believe that a server factory (`getServer`) provides +functionality equivalent to the per-session logic of mcp-go, with a smaller API +surface and fewer overlapping concepts. + +Additionally, individual handlers and transports here have a minimal API, and +do not expose internal details. Customization of things like handlers or +session management is intended to be implemented with middleware and/or +compositional patterns. + +#### Other transports + +We also provide a couple of transport implementations for special scenarios. An InMemoryTransport can be used when the client and server reside in the same process. A LoggingTransport is a middleware layer that logs RPC logs to a desired location, specified as an io.Writer. @@ -302,32 +347,6 @@ type LoggingTransport struct { /* ... */ } func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport ``` -**Differences from mcp-go**: The Go team has a battle-tested JSON-RPC -implementation that we use for gopls, our Go LSP server. We are using the new -version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 -features, including cancellation. - -The `Transport` interface here is lower-level than that of mcp-go, but serves a -similar purpose. We believe the lower-level interface is easier to implement. - -In mcp-go, server authors create an `MCPServer`, populate it with tools, -resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. These -also use session IDs, which are exposed. Users can manage their own sessions -with `RegisterSession` and `UnregisterSession`. - -We find the similarity in names among the three server types to be confusing, -and we could not discover any uses of the session methods in the open-source -ecosystem. In our design, server authors create a `Server`, and then -connect it to a `Transport`. An `SSEHTTPHandler` manages sessions for -incoming SSE connections, but does not expose them.
HTTP handlers accept a -server constructor, rather than Server, to allow for stateful or "per-session" -servers. - -Individual handlers and transports here have a minimal smaller API, and do not -expose internal details. Customization of things like handlers or session -management is intended to be implemented with middleware and/or compositional -patterns. - ### Protocol types Types needed for the protocol are generated from the @@ -373,7 +392,7 @@ type Content struct { ``` **Differences from mcp-go**: these types are largely similar, but our type -generation flattens types rather than using struct embedding. +generator flattens types rather than using struct embedding. ### Clients and Servers @@ -387,13 +406,18 @@ and resources from servers. Additionally, handlers for these features may themselves be stateful, for example if a tool handler caches state from earlier requests in the session. -We believe that in the common case, any change to a client or server, -such as adding a tool, is intended for all its peers. -It is therefore more useful to allow multiple connections from a client, and to -a server. This is similar to the `net/http` packages, in which an `http.Client` -and `http.Server` each may handle multiple unrelated connections. When users -add features to a client or server, all connected peers are notified of the -change. +We believe that in the common case, any change to a client or server, such as +adding a tool, is intended for all its peers. It is therefore more useful to +allow multiple connections from a client, and to a server. This is similar to +the `net/http` packages, in which an `http.Client` and `http.Server` each may +handle multiple unrelated connections. When users add features to a client or +server, all connected peers are notified of the change. 
+ +Supporting multiple connections to servers (and from clients) still allows for +stateful components, as it is up to the user to decide whether or not to create +distinct servers/clients for each connection. For example, if the user wants to +create a distinct server for each new connection, they can do so in the +`getServer` factory passed to transport handlers. Following the terminology of the [spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), @@ -423,6 +447,7 @@ func (*ClientSession) Client() *Client func (*ClientSession) Close() error func (*ClientSession) Wait() error // Methods for calling through the ClientSession are described below. +// For example: ClientSession.ListTools. type Server struct { /* ... */ } func NewServer(name, version string, opts *ServerOptions) *Server @@ -436,6 +461,7 @@ func (*ServerSession) Server() *Server func (*ServerSession) Close() error func (*ServerSession) Wait() error // Methods for calling through the ServerSession are described below. +// For example: ServerSession.ListRoots. ``` Here's an example of these APIs from the client side: @@ -472,17 +498,18 @@ session until the client disconnects: func (*Server) Run(context.Context, Transport) ``` -**Differences from mcp-go**: the Server APIs are very similar to mcp-go, -though the association between servers and transports is different. In -mcp-go, a single server is bound to what we would call an `SSEHTTPHandler`, -and reused for all client sessions. As discussed above, the transport -abstraction here is differentiated from HTTP serving, and the `Server.Connect` -method provides a consistent API for binding to an arbitrary transport. Servers -here do not have methods for sending notifications or calls, because they are -logically distinct from the `ServerSession`. 
In mcp-go, servers are `n:1`, -but there is no abstraction of a server session: sessions are addressed in -Server APIs through their `sessionID`: `SendNotificationToAllClients`, -`SendNotificationToClient`, `SendNotificationToSpecificClient`. +**Differences from mcp-go**: the Server APIs are similar to mcp-go, though the +association between servers and transports is different. In mcp-go, a single +server is bound to what we would call an `SSEHTTPHandler`, and reused for all +sessions. Per-session behavior is implemented through a 'session tool' overlay. +As discussed above, the transport abstraction here is differentiated from HTTP +serving, and the `Server.Connect` method provides a consistent API for binding +to an arbitrary transport. Servers here do not have methods for sending +notifications or calls, because they are logically distinct from the +`ServerSession`. In mcp-go, servers are `n:1`, but there is no abstraction of a +server session: sessions are addressed in Server APIs through their +`sessionID`: `SendNotificationToAllClients`, `SendNotificationToClient`, +`SendNotificationToSpecificClient`. The client API here is different, since clients and client sessions are conceptually distinct. The `ClientSession` is closer to mcp-go's notion of @@ -503,27 +530,25 @@ func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsRe ``` Our SDK has a method for every RPC in the spec, and except for `CallTool`, -their signatures all share this form. -We do this, rather than providing more convenient shortcut signatures, -to maintain backward compatibility if the spec makes backward-compatible changes -such as adding a new property to the request parameters +their signatures all share this form.
We do this, rather than providing more +convenient shortcut signatures, to maintain backward compatibility if the spec +makes backward-compatible changes such as adding a new property to the request +parameters (as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), for example). -To avoid boilerplate, we don't repeat this -signature for RPCs defined in the spec; readers may assume it when we mention a -"spec method." +To avoid boilerplate, we don't repeat this signature for RPCs defined in the +spec; readers may assume it when we mention a "spec method." `CallTool` is the only exception: for convenience, it takes the tool name and -arguments, with an options struct for additional request fields. -Our SDK has a method for every RPC in the spec, and their signatures all share -this form. To avoid boilerplate, we don't repeat this signature for RPCs -defined in the spec; readers may assume it when we mention a "spec method." +arguments, with an options struct for additional request fields. See the +section on Tools below for details. -Why do we use params instead of the full request? JSON-RPC requests consist of a method -name and a set of parameters, and the method is already encoded in the Go method name. -Technically, the MCP spec could add a field to a request while preserving backward -compatibility, which would break the Go SDK's compatibility. But in the unlikely event -that were to happen, we would add that field to the Params struct. +Why do we use params instead of the full JSON-RPC request? As much as possible, +we endeavor to hide JSON-RPC details when they are not relevant to the business +logic of your client or server. In this case, the additional information in the +JSON-RPC request is just the request ID and method name; the request ID is +irrelevant, and the method name is implied by the name of the Go method +providing the API. 
We believe that any change to the spec that would require callers to pass a new a parameter is not backward compatible. Therefore, it will always work to pass @@ -595,9 +620,9 @@ code. ```go type JSONRPCError struct { - Code int64 `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data,omitempty"` + Code int64 `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data,omitempty"` } ``` @@ -658,9 +683,9 @@ func (c *ServerSession) Ping(ctx context.Context, *PingParams) error Additionally, client and server sessions can be configured with automatic keepalive behavior. If the `KeepAlive` option is set to a non-zero duration, -it defines an -interval for regular "ping" requests. If the peer fails to respond to pings -originating from the keepalive check, the session is automatically closed. +it defines an interval for regular "ping" requests. If the peer fails to +respond to pings originating from the keepalive check, the session is +automatically closed. ```go type ClientOptions struct { @@ -847,6 +872,19 @@ directly with assignment or a struct literal. Client sessions can call the spec method `ListTools` or an iterator method `Tools` to list the available tools. +As mentioned above, the client session method `CallTool` has a non-standard +signature, so that `CallTool` can handle the marshalling of tool arguments: the +type of `CallToolParams.Arguments` is `json.RawMessage`, to delegate +unmarshalling to the tool handler. + +```go +func (c *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) + +type CallToolOptions struct { + ProgressToken any // string or int +} +``` + **Differences from mcp-go**: using variadic options to configure tools was significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. 
@@ -972,6 +1010,11 @@ s.AddResources(&mcp.ServerResource{ Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. +**Differences from mcp-go**: for symmetry with tools and prompts, we use +`AddResources` rather than `AddResource`. Additionally, the `ResourceHandler` +returns a `ReadResourceResult`, rather than just its content, for compatibility +with future evolution of the spec. + #### Subscriptions ClientSessions can manage change notifications on particular resources: @@ -1105,3 +1148,7 @@ more pages exist. In addition to the `List` methods, the SDK provides an iterator method for each list operation. This simplifies pagination for clients by automatically handling the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. + +**Differences with mcp-go**: the PageSize configuration is set with a +configuration field rather than a variadic option. Additionally, this design +proposes pagination by default, as this is likely desirable for most servers. diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index d6ee36915b5..d1cd6c7a900 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -10,13 +10,13 @@ // To get started, create either a [Client] or [Server], and connect it to a // peer using a [Transport]. The diagram below illustrates how this works: // -// Client Server -// ⇅ (jsonrpc2) ⇅ -// Client Transport ⇄ Server Transport ⇄ ServerSession +// Client Server +// ⇅ (jsonrpc2) ⇅ +// ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession // // A [Client] is an MCP client, which can be configured with various client // capabilities. Clients may be connected to a [Server] instance -// using the [Client.Start] method. +// using the [Client.Connect] method. // // Similarly, a [Server] is an MCP server, which can be configured with various // server capabilities. 
Servers may be connected to one or more [Client] diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index d6a2b117269..2c4c757f9bc 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -20,8 +20,8 @@ type PromptHandler func(context.Context, *ServerSession, map[string]string) (*Ge // A Prompt is a prompt definition bound to a prompt handler. type ServerPrompt struct { - Definition *Prompt - Handler PromptHandler + Prompt *Prompt + Handler PromptHandler } // NewPrompt is a helper to use reflection to create a prompt for the given @@ -41,7 +41,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, panic(fmt.Sprintf("handler request type must be a struct")) } prompt := &ServerPrompt{ - Definition: &Prompt{ + Prompt: &Prompt{ Name: name, Description: description, }, @@ -54,7 +54,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if prop.Type != "string" { panic(fmt.Sprintf("handler type must consist only of string fields")) } - prompt.Definition.Arguments = append(prompt.Definition.Arguments, &PromptArgument{ + prompt.Prompt.Arguments = append(prompt.Prompt.Arguments, &PromptArgument{ Name: name, Description: prop.Description, Required: required[name], @@ -95,16 +95,16 @@ func (s promptSetter) set(p *ServerPrompt) { s(p) } // Required and Description, and panics when encountering any other option. 
func Argument(name string, opts ...SchemaOption) PromptOption { return promptSetter(func(p *ServerPrompt) { - i := slices.IndexFunc(p.Definition.Arguments, func(arg *PromptArgument) bool { + i := slices.IndexFunc(p.Prompt.Arguments, func(arg *PromptArgument) bool { return arg.Name == name }) var arg *PromptArgument if i < 0 { - i = len(p.Definition.Arguments) + i = len(p.Prompt.Arguments) arg = &PromptArgument{Name: name} - p.Definition.Arguments = append(p.Definition.Arguments, arg) + p.Prompt.Arguments = append(p.Prompt.Arguments, arg) } else { - arg = p.Definition.Arguments[i] + arg = p.Prompt.Arguments[i] } for _, opt := range opts { switch v := opt.(type) { @@ -116,6 +116,6 @@ func Argument(name string, opts ...SchemaOption) PromptOption { panic(fmt.Sprintf("unsupported prompt argument schema option %T", opt)) } } - p.Definition.Arguments[i] = arg + p.Prompt.Arguments[i] = arg }) } diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index fa2e3bc0a71..4de5aa93d9f 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -46,8 +46,8 @@ func TestNewPrompt(t *testing.T) { }, } for _, test := range tests { - if diff := cmp.Diff(test.want, test.prompt.Definition.Arguments); diff != "" { - t.Errorf("NewPrompt(%v) mismatch (-want +got):\n%s", test.prompt.Definition.Name, diff) + if diff := cmp.Diff(test.want, test.prompt.Prompt.Arguments); diff != "" { + t.Errorf("NewPrompt(%v) mismatch (-want +got):\n%s", test.prompt.Prompt.Name, diff) } } } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 1e397c6d8f3..d91c0ef9cf0 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -53,8 +53,8 @@ func NewServer(name, version string, opts *ServerOptions) *Server { name: name, version: version, opts: *opts, - prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Definition.Name }), - tools: newFeatureSet(func(t *ServerTool) string { return t.Definition.Name }), + prompts: newFeatureSet(func(p 
*ServerPrompt) string { return p.Prompt.Name }), + tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), } } @@ -170,7 +170,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr defer s.mu.Unlock() res := new(ListPromptsResult) for p := range s.prompts.all() { - res.Prompts = append(res.Prompts, p.Definition) + res.Prompts = append(res.Prompts, p.Prompt) } return res, nil } @@ -191,7 +191,7 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool defer s.mu.Unlock() res := new(ListToolsResult) for t := range s.tools.all() { - res.Tools = append(res.Tools, t.Definition) + res.Tools = append(res.Tools, t.Tool) } return res, nil } diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index f7b04660f08..43ebe1bfdb4 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -17,8 +17,8 @@ type ToolHandler func(context.Context, *ServerSession, *CallToolParams) (*CallTo // A Tool is a tool definition that is bound to a tool handler. type ServerTool struct { - Definition *Tool - Handler ToolHandler + Tool *Tool + Handler ToolHandler } // NewTool is a helper to make a tool using reflection on the given handler. 
@@ -59,7 +59,7 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * return res, nil } t := &ServerTool{ - Definition: &Tool{ + Tool: &Tool{ Name: name, Description: description, InputSchema: schema, @@ -94,7 +94,7 @@ func (s toolSetter) set(t *ServerTool) { s(t) } func Input(opts ...SchemaOption) ToolOption { return toolSetter(func(t *ServerTool) { for _, opt := range opts { - opt.set(t.Definition.InputSchema) + opt.set(t.Tool.InputSchema) } }) } diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index 88694af3e75..ae4e5ee93e5 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -83,8 +83,8 @@ func TestNewTool(t *testing.T) { }, } for _, test := range tests { - if diff := cmp.Diff(test.want, test.tool.Definition.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("NewTool(%v) mismatch (-want +got):\n%s", test.tool.Definition.Name, diff) + if diff := cmp.Diff(test.want, test.tool.Tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("NewTool(%v) mismatch (-want +got):\n%s", test.tool.Tool.Name, diff) } } } From 07c24ad5cd7c20c79c2e15d1e645ae9f81c5f1de Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 14 May 2025 22:18:22 +0000 Subject: [PATCH 080/196] internal/mcp: update README Update the README to reflect recent API changes. Also, remove some design discussion in the README, to instead refer to the canonical design doc. 
Change-Id: Id77d34c6364d0a35e33f6952a291b3b659ad732b Reviewed-on: https://go-review.googlesource.com/c/tools/+/672619 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/README.md | 81 +++++++++++++++--------------------------- 1 file changed, 28 insertions(+), 53 deletions(-) diff --git a/internal/mcp/README.md b/internal/mcp/README.md index 1f616746bb1..761df9b2d27 100644 --- a/internal/mcp/README.md +++ b/internal/mcp/README.md @@ -1,10 +1,12 @@ -# MCP package +# MCP SDK prototype [![PkgGoDev](https://pkg.go.dev/badge/golang.org/x/tools)](https://pkg.go.dev/golang.org/x/tools/internal/mcp) -The mcp package provides an SDK for writing [model context protocol](https://modelcontextprotocol.io/introduction) -clients and servers. It is a work-in-progress. As of writing, it is a prototype -to explore the design space of client/server transport and binding. +The mcp package provides a software development kit (SDK) for writing clients +and servers of the [model context +protocol](https://modelcontextprotocol.io/introduction). It is unstable, and +will change in breaking ways in the future. As of writing, it is a prototype to +explore the design space of client/server transport and binding. ## Installation @@ -32,12 +34,14 @@ func main() { client := mcp.NewClient("mcp-client", "v1.0.0", nil) // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) - if err := client.Connect(ctx, transport, nil); err != nil { + session, err := client.Connect(ctx, transport) + if err != nil { log.Fatal(err) } + defer session.Close() // Call a tool on the server. 
- if content, err := client.CallTool(ctx, "greet", map[string]any{"name": "you"}); err != nil { - log.Printf("CallTool returns error: %v", err) + if content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { + log.Printf("CallTool failed: %v", err) } else { log.Printf("CallTool returns: %v", content) } @@ -59,64 +63,31 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]mcp.Content, error) { - return []mcp.Content{ - mcp.TextContent{Text: "Hi " + params.Name}, +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { + return []*mcp.Content{ + mcp.NewTextContent("Hi " + params.Name), }, nil } func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) + server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) // Run the server over stdin/stdout, until the client diconnects - _ = server.Run(context.Background(), mcp.NewStdIOTransport(), nil) + _ = server.Run(context.Background(), mcp.NewStdIOTransport()) } ``` -## Core Concepts +## Design -The mcp package leverages Go's [reflect](https://pkg.go.dev/reflect) package to -automatically generate the JSON schema for your tools / prompts' input -parameters. As an mcp server developer, ensure your input parameter structs -include the standard `"json"` tags (as demonstrated in the `HiParams` example). -Refer to the [jsonschema](https://www.google.com/search?q=internal/jsonschema/infer.go) -package for detailed information on schema inference. - -### Tools - -Tools in MCP allow servers to expose executable functions that can be invoked by clients and used by LLMs to perform actions. The server can add tools using - -```go -... -server := mcp.NewServer("greeter", "v1.0.0", nil) -server.AddTools(mcp.MakeTool("greet", "say hi", SayHi)) -... 
-``` - -### Prompts - -Prompts enable servers to define reusable prompt templates and workflows that clients can easily surface to users and LLMs. The server can add prompts by using - -```go -... -server := mcp.NewServer("greeter", "v0.0.1", nil) -server.AddPrompts(mcp.MakePrompt("greet", "", PromptHi)) -... -``` - -### Resources - -Resources are a core primitive in the Model Context Protocol (MCP) that allow servers to expose data and content that can be read by clients and used as context for LLM interactions. - - - -Resources are not supported yet. +See [design.md](./design/design.md) for the SDK design. That document is +canonical: given any divergence between the design doc and this prototype, the +doc reflects the latest design. ## Testing -To test your client or server using stdio transport, you can use local -transport instead of creating real stdio transportation. See [example](server_example_test.go). +To test your client or server using stdio transport, you can use an in-memory +transport. See [example](server_example_test.go). To test your client or server using sse transport, you can use the [httptest](https://pkg.go.dev/net/http/httptest) package. See [example](sse_example_test.go). @@ -128,6 +99,10 @@ If you encounter a conduct-related issue, please mail conduct@golang.org. ## License -Unless otherwise noted, the Go source files are distributed under the BSD-style license found in the [LICENSE](../../LICENSE) file. +Unless otherwise noted, the Go source files are distributed under the BSD-style +license found in the [LICENSE](../../LICENSE) file. -Upon a potential move to [modelcontextprotocol](https://github.com/modelcontextprotocol), the license will be updated to the MIT License, and the license header will reflect the Go MCP SDK Authors. 
+Upon a potential move to the +[modelcontextprotocol](https://github.com/modelcontextprotocol) organization, +the license will be updated to the MIT License, and the license header will +reflect the Go MCP SDK Authors. From 6731e88867c2a96d1903dd3ef8bf3b598e2dc45f Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Thu, 15 May 2025 08:13:05 -0400 Subject: [PATCH 081/196] internal/diff/lcs: fix flaky test A parameter was too small, so the computed lcs was occasionally incorrect (parts per million). (Another badly thought-out attempt to be efficient.( Change-Id: I5a7505419a877030db4a9a04eb2597eb8f4dcba0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673135 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/diff/lcs/old.go | 6 +++--- internal/diff/lcs/old_test.go | 10 +--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/internal/diff/lcs/old.go b/internal/diff/lcs/old.go index c0d43a6c2c7..4c346706a75 100644 --- a/internal/diff/lcs/old.go +++ b/internal/diff/lcs/old.go @@ -105,7 +105,7 @@ func forward(e *editGraph) lcs { return ans } // from D to D+1 - for D := 0; D < e.limit; D++ { + for D := range e.limit { e.setForward(D+1, -(D + 1), e.getForward(D, -D)) if ok, ans := e.fdone(D+1, -(D + 1)); ok { return ans @@ -206,7 +206,7 @@ func backward(e *editGraph) lcs { return ans } // from D to D+1 - for D := 0; D < e.limit; D++ { + for D := range e.limit { e.setBackward(D+1, -(D + 1), e.getBackward(D, -D)-1) if ok, ans := e.bdone(D+1, -(D + 1)); ok { return ans @@ -300,7 +300,7 @@ func twosided(e *editGraph) lcs { e.setBackward(0, 0, e.ux) // from D to D+1 - for D := 0; D < e.limit; D++ { + for D := range e.limit { // just finished a backwards pass, so check if got, ok := e.twoDone(D, D); ok { return e.twolcs(D, D, got) diff --git a/internal/diff/lcs/old_test.go b/internal/diff/lcs/old_test.go index 32639e0c69f..035465fa34c 100644 --- a/internal/diff/lcs/old_test.go +++ b/internal/diff/lcs/old_test.go @@ -106,14 +106,6 @@ 
func TestRegressionOld003(t *testing.T) { } func TestRandOld(t *testing.T) { - // This test has been observed to fail on 0.5% of runs, - // for example, using these seed pairs: - // - 14495503613572398601, 14235960032715252551 - // - 4604737379815557952, 3544687276571693387 - // - 5419883078200329767, 11421218423438832472 - // - 16595049989808072974, 2139246309634979125 - // - 7079260183459082455, 16563974573191788291 - // TODO(adonovan): fix. rng := rng(t) for i := range 1000 { @@ -123,7 +115,7 @@ func TestRandOld(t *testing.T) { b := []rune(randstr(rng, "abωc", 16)) seq := runesSeqs{a, b} - const lim = 24 // large enough to get true lcs + const lim = 0 // make sure we get the lcs (24 was too small) _, forw := compute(seq, forward, lim) _, back := compute(seq, backward, lim) _, two := compute(seq, twosided, lim) From 78956f956474ab4a134716d2346661960547f7b7 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 15 May 2025 15:00:16 +0000 Subject: [PATCH 082/196] internal/mcp/design: yet more cleanup; define the Roots iterator Do another pass of clean up for the design prior to posting. Also, add a cached Roots iterator, as discussed. Change-Id: If8d444b23c7332b89effb2fce8b3be28841f0882 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673195 TryBot-Bypass: Sam Thanawalla Reviewed-by: Sam Thanawalla TryBot-Bypass: Robert Findley Reviewed-by: Jonathan Amsterdam --- internal/mcp/design/design.md | 61 ++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index fbd59560306..1c1d559ef95 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -69,13 +69,13 @@ single shared package, the `mcp` package. This is inconsistent with other MCP SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or `google.golang.org/grpc`. We believe that having a single package aids discoverability in package documentation and in the IDE. 
Furthermore, it avoids -somwhat arbitrary decisions about package structure that may be rendered -inaccurate by future evolution of the spec. +arbitrary decisions about package structure that may be rendered inaccurate by +future evolution of the spec. Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) belongs in a separate package. -Therefore, this is the package layout, assuming +Therefore, this is the core package layout, assuming github.com/modelcontextprotocol/go-sdk as the module path. - `github.com/modelcontextprotocol/go-sdk/mcp`: the bulk of the user facing API @@ -91,7 +91,7 @@ or wrappers. **Difference from mcp-go**: Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, `server` and `transport` packages. -### jsonrpc2 and Transports +### JSON-RPC and Transports The MCP is defined in terms of client-server communication over bidirectional JSON-RPC message streams. Specifically, version `2025-03-26` of the spec @@ -124,8 +124,8 @@ type Transport interface { // A Stream is a bidirectional jsonrpc2 Stream. type Stream interface { - Read(ctx context.Context) (jsonrpc2.Message, error) - Write(ctx context.Context, jsonrpc2.Message) error + Read(ctx context.Context) (JSONRPCMessage, error) + Write(ctx context.Context, JSONRPCMessage) error Close() error } ``` @@ -161,7 +161,7 @@ In the MCP Spec, the **stdio** transport uses newline-delimited JSON to communicate over stdin/stdout. It's possible to model both client side and server side of this communication with a shared type that communicates over an `io.ReadWriteCloser`. However, for the purposes of future-proofing, we should -use a distinct types for both client and server stdio transport. +use a different types for client and server stdio transport. 
The `CommandTransport` is the client side of the stdio transport, and connects by starting a command and binding its jsonrpc2 stream to its @@ -293,7 +293,7 @@ func (*StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request func (*StreamableHTTPHandler) Close() error // Unlike the SSE transport, the streamable transport constructor accepts a -// session ID, not an endpoint, along with the http response for the request +// session ID, not an endpoint, along with the HTTP response for the request // that created the session. It is the caller's responsibility to delegate // requests to this session. type StreamableServerTransport struct { /* ... */ } @@ -316,7 +316,7 @@ overlays tools for a specific session. We find the similarity in names among the three server types to be confusing, and we could not discover any uses of the session methods in the open-source -ecosystem. Furthermore, we believe that a server factor (`getServer`) provides +ecosystem. Furthermore, we believe that a server factory (`getServer`) provides equivalent functionality as the per-session logic of mcp-go, with a smaller API surface and fewer overlapping concepts. @@ -353,7 +353,7 @@ Types needed for the protocol are generated from the [JSON schema of the MCP spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). These types will be included in the `mcp` package, but will be unexported -unless they are needed for the user-facing API. Notably, JSON-RPC message types +unless they are needed for the user-facing API. Notably, JSON-RPC request types are elided, since they are handled by the `jsonrpc2` package and should not be observed by the user. @@ -389,6 +389,10 @@ type Content struct { Data []byte `json:"data,omitempty"` Resource *ResourceContents `json:"resource,omitempty"` } + +// NewTextContent creates a [Content] with text. +func NewTextContent(text string) *Content +// etc. 
``` **Differences from mcp-go**: these types are largely similar, but our type @@ -437,7 +441,8 @@ method. ```go type Client struct { /* ... */ } func NewClient(name, version string, opts *ClientOptions) *Client -func (c *Client) Connect(context.Context, Transport) (*ClientSession, error) +func (*Client) Connect(context.Context, Transport) (*ClientSession, error) +func (*Client) Sessions() iter.Seq[*ClientSession] // Methods for adding/removing client features are described below. type ClientOptions struct { /* ... */ } // described below @@ -451,7 +456,8 @@ func (*ClientSession) Wait() error type Server struct { /* ... */ } func NewServer(name, version string, opts *ServerOptions) *Server -func (s *Server) Connect(context.Context, Transport) (*ServerSession, error) +func (*Server) Connect(context.Context, Transport) (*ServerSession, error) +func (*Server) Sessions() iter.Seq[*ServerSession] // Methods for adding/removing server features are described below. type ServerOptions struct { /* ... */ } // described below @@ -517,7 +523,7 @@ Client. For both clients and servers, mcp-go uses variadic options to customize behavior, whereas an options struct is used here. We felt that in this case, an -options struct would be more readable, and result in cleaner package +options struct would be more readable, and result in simpler package documentation. ### Spec Methods @@ -555,7 +561,7 @@ parameter is not backward compatible. Therefore, it will always work to pass `nil` for any `XXXParams` argument that isn't currently necessary. For example, it is okay to call `Ping` like so: ```go -err := session.Ping(ctx, nil)` +err := session.Ping(ctx, nil) ``` #### Iterator Methods @@ -630,9 +636,9 @@ As described by the [spec](https://modelcontextprotocol.io/specification/2025-03-26/server/tools#error-handling), tool execution errors are reported in tool results. 
-**Differences from mcp-go**: the `JSONRPCError` type here does not expose -details that are irrelevant or can be inferred from the caller (ID and Method). -Otherwise, this behavior is similar. +**Differences from mcp-go**: the `JSONRPCError` type here does not include ID +and Method, which can be inferred from the caller. Otherwise, this behavior is +similar. ### Cancellation @@ -707,8 +713,9 @@ client, not server, and the keepalive option is only provided for SSE servers ### Roots -Clients support the MCP Roots feature out of the box, including roots-changed notifications. -Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots`: +Clients support the MCP Roots feature, including roots-changed notifications. +Roots can be added and removed from a `Client` with `AddRoots` and +`RemoveRoots`: ```go // AddRoots adds the given roots to the client, @@ -722,10 +729,10 @@ func (*Client) AddRoots(roots ...*Root) func (*Client) RemoveRoots(uris ...string) ``` -Server sessions can call the spec method `ListRoots` to get the roots. If a server installs a -`RootsChangedHandler`, it will be called when the client sends a roots-changed -notification, which happens whenever the list of roots changes after a -connection has been established. +Server sessions can call the spec method `ListRoots` to get the roots. If a +server installs a `RootsChangedHandler`, it will be called when the client +sends a roots-changed notification, which happens whenever the list of roots +changes after a connection has been established. ```go type ServerOptions { @@ -735,6 +742,14 @@ type ServerOptions { } ``` +The `Roots` method provides a +[cached](https://modelcontextprotocol.io/specification/2025-03-26/client/roots#implementation-guidelines) +iterator of the root set, invalidated when roots change. 
+ +```go +func (*ServerSession) Roots(context.Context) (iter.Seq[*Root, error]) +``` + ### Sampling Clients that support sampling are created with a `CreateMessageHandler` option From b1e5d850d388988c6b7fc31b5ebddb7f47684195 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 15 May 2025 13:11:35 -0400 Subject: [PATCH 083/196] internal/gocommand: move and reenable TestRmdirAfterGoList ...this time with logic to WalkDir the unremovable directory. Updates golang/go#73481 Change-Id: If7cb9a4de4b02b02d5771509c2a125e1460a4bde Reviewed-on: https://go-review.googlesource.com/c/tools/+/673197 Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- go/packages/packages_test.go | 88 ---------------------------- internal/gocommand/invoke_test.go | 97 +++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 88 deletions(-) diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index 2623aa5a03b..ae3cbb6bb2b 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -28,9 +28,7 @@ import ( "time" "github.com/google/go-cmp/cmp" - "golang.org/x/sync/errgroup" "golang.org/x/tools/go/packages" - "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/packagesinternal" "golang.org/x/tools/internal/packagestest" "golang.org/x/tools/internal/testenv" @@ -3402,89 +3400,3 @@ func writeTree(t *testing.T, archive string) string { } return root } - -// This is not a test of go/packages at all: it's a test of whether it -// is possible to delete the directory used by go list once it has -// finished. It is intended to evaluate the hypothesis (to explain -// issue #71544) that the go command, on Windows, occasionally fails -// to release all its handles to the temporary directory even when it -// should have finished. -// -// If this test ever fails, the combination of the gocommand package -// and the go command itself has a bug. 
-func TestRmdirAfterGoList_Runner(t *testing.T) { - t.Skip("golang/go#73503: this test is frequently flaky") - - testRmdirAfterGoList(t, func(ctx context.Context, dir string) { - var runner gocommand.Runner - stdout, stderr, friendlyErr, err := runner.RunRaw(ctx, gocommand.Invocation{ - Verb: "list", - Args: []string{"-json", "example.com/p"}, - WorkingDir: dir, - }) - if ctx.Err() != nil { - return // don't report error if canceled - } - if err != nil || friendlyErr != nil { - t.Fatalf("go list failed: %v, %v (stdout=%s stderr=%s)", - err, friendlyErr, stdout, stderr) - } - }) -} - -// TestRmdirAfterGoList_Direct is a variant of -// TestRmdirAfterGoList_Runner that executes go list directly, to -// control for the substantial logic of the gocommand package. -// -// If this test ever fails, the go command itself has a bug. -func TestRmdirAfterGoList_Direct(t *testing.T) { - testRmdirAfterGoList(t, func(ctx context.Context, dir string) { - cmd := exec.Command("go", "list", "-json", "example.com/p") - cmd.Dir = dir - cmd.Stdout = new(strings.Builder) - cmd.Stderr = new(strings.Builder) - err := cmd.Run() - if ctx.Err() != nil { - return // don't report error if canceled - } - if err != nil { - t.Fatalf("go list failed: %v (stdout=%s stderr=%s)", - err, cmd.Stdout, cmd.Stderr) - } - }) -} - -func testRmdirAfterGoList(t *testing.T, f func(ctx context.Context, dir string)) { - testenv.NeedsExec(t) - - dir := t.TempDir() - if err := os.Mkdir(filepath.Join(dir, "p"), 0777); err != nil { - t.Fatalf("mkdir p: %v", err) - } - - // Create a go.mod file and 100 trivial Go files for the go command to read. 
- if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com"), 0666); err != nil { - t.Fatal(err) - } - for i := range 100 { - filename := filepath.Join(dir, fmt.Sprintf("p/%d.go", i)) - if err := os.WriteFile(filename, []byte("package p"), 0666); err != nil { - t.Fatal(err) - } - } - - g, ctx := errgroup.WithContext(context.Background()) - for range 10 { - g.Go(func() error { - f(ctx, dir) - // Return an error so that concurrent invocations are canceled. - return fmt.Errorf("oops") - }) - } - g.Wait() // ignore expected error - - // This is the critical operation. - if err := os.RemoveAll(dir); err != nil { - t.Fatalf("failed to remove temp dir: %v", err) - } -} diff --git a/internal/gocommand/invoke_test.go b/internal/gocommand/invoke_test.go index ab1c7b1a11f..7e29135633c 100644 --- a/internal/gocommand/invoke_test.go +++ b/internal/gocommand/invoke_test.go @@ -6,8 +6,15 @@ package gocommand_test import ( "context" + "fmt" + "io/fs" + "os" + "os/exec" + "path/filepath" + "strings" "testing" + "golang.org/x/sync/errgroup" "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/testenv" ) @@ -23,3 +30,93 @@ func TestGoVersion(t *testing.T) { t.Error(err) } } + +// This is not a test of go/packages at all: it's a test of whether it +// is possible to delete the directory used by go list once it has +// finished. It is intended to evaluate the hypothesis (to explain +// issue #71544) that the go command, on Windows, occasionally fails +// to release all its handles to the temporary directory even when it +// should have finished. +// +// If this test ever fails, the combination of the gocommand package +// and the go command itself has a bug; this has been observed (#73503). 
+func TestRmdirAfterGoList_Runner(t *testing.T) { + testRmdirAfterGoList(t, func(ctx context.Context, dir string) { + var runner gocommand.Runner + stdout, stderr, friendlyErr, err := runner.RunRaw(ctx, gocommand.Invocation{ + Verb: "list", + Args: []string{"-json", "example.com/p"}, + WorkingDir: dir, + }) + if ctx.Err() != nil { + return // don't report error if canceled + } + if err != nil || friendlyErr != nil { + t.Fatalf("go list failed: %v, %v (stdout=%s stderr=%s)", + err, friendlyErr, stdout, stderr) + } + }) +} + +// TestRmdirAfterGoList_Direct is a variant of +// TestRmdirAfterGoList_Runner that executes go list directly, to +// control for the substantial logic of the gocommand package. +// +// If this test ever fails, the go command itself has a bug; as of May +// 2025 this has never been observed. +func TestRmdirAfterGoList_Direct(t *testing.T) { + testRmdirAfterGoList(t, func(ctx context.Context, dir string) { + cmd := exec.Command("go", "list", "-json", "example.com/p") + cmd.Dir = dir + cmd.Stdout = new(strings.Builder) + cmd.Stderr = new(strings.Builder) + err := cmd.Run() + if ctx.Err() != nil { + return // don't report error if canceled + } + if err != nil { + t.Fatalf("go list failed: %v (stdout=%s stderr=%s)", + err, cmd.Stdout, cmd.Stderr) + } + }) +} + +func testRmdirAfterGoList(t *testing.T, f func(ctx context.Context, dir string)) { + testenv.NeedsExec(t) + + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, "p"), 0777); err != nil { + t.Fatalf("mkdir p: %v", err) + } + + // Create a go.mod file and 100 trivial Go files for the go command to read. 
+ if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com"), 0666); err != nil { + t.Fatal(err) + } + for i := range 100 { + filename := filepath.Join(dir, fmt.Sprintf("p/%d.go", i)) + if err := os.WriteFile(filename, []byte("package p"), 0666); err != nil { + t.Fatal(err) + } + } + + g, ctx := errgroup.WithContext(context.Background()) + for range 10 { + g.Go(func() error { + f(ctx, dir) + // Return an error so that concurrent invocations are canceled. + return fmt.Errorf("oops") + }) + } + g.Wait() // ignore expected error + + // This is the critical operation. + if err := os.RemoveAll(dir); err != nil { + t.Errorf("failed to remove temp dir: %v", err) + // List the contents of the directory, for clues. + filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + t.Log(path, d, err) + return nil + }) + } +} From ec0dab2f66632c2b99db8119b02a5d8a334f378c Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 15 May 2025 17:59:06 +0000 Subject: [PATCH 084/196] internal/mcp/design: sync with public design Due to the irregularities of github-flavored markdown, I had to join all paragraphs before posting the public design. Additionally, I made a few wording tweaks. If we want to continue editing the design here, we should keep it in sync with the public design discussion. Therefore, copy back those changes. 
Change-Id: Ie6ab3b3f59b923bb58540e86bbb1451ce0b4815d Reviewed-on: https://go-review.googlesource.com/c/tools/+/673199 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/design/design.md | 575 ++++++++-------------------------- 1 file changed, 134 insertions(+), 441 deletions(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 1c1d559ef95..e3c02db2284 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -1,119 +1,61 @@ -# Go MCP SDK design - -This document discusses the design of a Go SDK for the [model context -protocol](https://modelcontextprotocol.io/specification/2025-03-26). It is -intended to seed a GitHub discussion about the official Go MCP SDK. - -The -[golang.org/x/tools/internal/mcp](https://pkg.go.dev/golang.org/x/tools/internal/mcp@master) -package contains a prototype that we built to explore the MCP design space. -Many of the ideas there are present in this document. However, we have diverged -from and expanded on the APIs of that prototype, and this document should be -considered canonical. - -## Similarities and differences with mark3labs/mcp-go - -The most popular unofficial MCP SDK for Go is -[mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go). As of this -writing, it is imported by over 400 packages that span over 200 modules. - -We admire mcp-go, and seriously considered simply adopting it as a starting -point for this SDK. However, as we looked at doing so, we realized that a -significant amount of its API would probably need to change. In some cases, -mcp-go has older APIs that predated newer variations—an obvious opportunity -for cleanup. In others, it took a batteries-included approach that is probably -not viable for an official SDK. In yet others, we simply think there is room for -API refinement, and we should take this opportunity to consider our options. 
-Therefore, we wrote this document as though it were proposing a new -implementation. Nevertheless, much of the API discussed here originated from or -was inspired by mcp-go and other unofficial SDKs. - -Since mcp-go is so influential and popular, we have noted significant -differences from its API in the sections below. Although the API here is not -compatible with mcp-go, translating between them should be straightforward in -most cases. -(Later, we will provide a detailed translation guide.) +# Go SDK Design + +This document discusses the design of a Go SDK for the [model context protocol](https://modelcontextprotocol.io/specification/2025-03-26). The [golang.org/x/tools/internal/mcp](https://pkg.go.dev/golang.org/x/tools/internal/mcp@master) package contains a prototype that we built to explore the MCP design space. Many of the ideas there are present in this document. However, we have diverged from and expanded on the APIs of that prototype, and this document should be considered canonical. + +## Similarities and differences with mark3labs/mcp-go (and others) + +The most popular unofficial MCP SDK for Go is [mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go). As of this writing, it is imported by over 400 packages that span over 200 modules. + +We admire mcp-go, and where possible tried to align with its design. However, the APIs here diverge in a number of ways in order to keep the official SDK minimal, allow for future spec evolution, and support additional features. We have noted significant differences from mcp-go in the sections below. Although the API here is not compatible with mcp-go, translating between them should be straightforward in most cases. (Later, we will provide a detailed translation guide.) + +Thank you to everyone who contributes to mcp-go and other Go SDKs. We hope that we can collaborate to leverage all that we've learned about MCP and Go in an official SDK. 
# Requirements -These may be obvious, but it's worthwhile to define goals for an official MCP -SDK. An official SDK should aim to be: - -- **complete**: it should be possible to implement every feature of the MCP - spec, and these features should conform to all of the semantics described by - the spec. -- **idiomatic**: as much as possible, MCP features should be modeled using - features of the Go language and its standard library. Additionally, the SDK - should repeat idioms from similar domains. -- **robust**: the SDK itself should be well tested and reliable, and should - enable easy testability for its users. -- **future-proof**: the SDK should allow for future evolution of the MCP spec, - in such a way that we can (as much as possible) avoid incompatible changes to - the SDK API. -- **extensible**: to best serve the previous four concerns, the SDK should be - minimal. However, it should admit extensibility using (for example) simple - interfaces, middleware, or hooks. +These may be obvious, but it's worthwhile to define goals for an official MCP SDK. An official SDK should aim to be: + +- **complete**: it should be possible to implement every feature of the MCP spec, and these features should conform to all of the semantics described by the spec. +- **idiomatic**: as much as possible, MCP features should be modeled using features of the Go language and its standard library. Additionally, the SDK should repeat idioms from similar domains. +- **robust**: the SDK itself should be well tested and reliable, and should enable easy testability for its users. +- **future-proof**: the SDK should allow for future evolution of the MCP spec, in such a way that we can (as much as possible) avoid incompatible changes to the SDK API. +- **extensible**: to best serve the previous four concerns, the SDK should be minimal. However, it should admit extensibility using (for example) simple interfaces, middleware, or hooks. 
# Design considerations -In the sections below, we visit each aspect of the MCP spec, in approximately -the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26) -For each, we discuss considerations for the Go implementation, and propose a Go API. +In the sections below, we visit each aspect of the MCP spec, in approximately the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26). For each, we discuss considerations for the Go implementation, and propose a Go API. ## Foundations ### Package layout -In the sections that follow, it is assumed that most of the MCP API lives in a -single shared package, the `mcp` package. This is inconsistent with other MCP -SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or -`google.golang.org/grpc`. We believe that having a single package aids -discoverability in package documentation and in the IDE. Furthermore, it avoids -arbitrary decisions about package structure that may be rendered inaccurate by -future evolution of the spec. +In the sections that follow, it is assumed that most of the MCP API lives in a single shared package, the `mcp` package. This is inconsistent with other MCP SDKs, but is consistent with Go packages like `net/http`, `net/rpc`, or `google.golang.org/grpc`. We believe that having a single package aids discoverability in package documentation and in the IDE. Furthermore, it avoids arbitrary decisions about package structure that may be rendered inaccurate by future evolution of the spec. -Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) -belongs in a separate package. +Functionality that is not directly related to MCP (like jsonschema or jsonrpc2) belongs in a separate package. -Therefore, this is the core package layout, assuming -github.com/modelcontextprotocol/go-sdk as the module path.
+Therefore, this is the core package layout, assuming github.com/modelcontextprotocol/go-sdk as the module path. - `github.com/modelcontextprotocol/go-sdk/mcp`: the bulk of the user facing API - `github.com/modelcontextprotocol/go-sdk/jsonschema`: a jsonschema implementation, with validation - `github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2`: a fork of x/tools/internal/jsonrpc2_v2 -The JSON-RPC implementation is hidden, to avoid tight coupling. As described in -the next section, the only aspects of JSON-RPC that need to be exposed in the -SDK are the message types, for the purposes of defining custom transports. We -can expose these types by promoting them from the `mcp` package using aliases -or wrappers. +The JSON-RPC implementation is hidden, to avoid tight coupling. As described in the next section, the only aspects of JSON-RPC that need to be exposed in the SDK are the message types, for the purposes of defining custom transports. We can expose these types by promoting them from the `mcp` package using aliases or wrappers. -**Difference from mcp-go**: Our `mcp` package includes all the functionality of -mcp-go's `mcp`, `client`, `server` and `transport` packages. +**Difference from mcp-go**: Our `mcp` package includes all the functionality of mcp-go's `mcp`, `client`, `server` and `transport` packages. ### JSON-RPC and Transports -The MCP is defined in terms of client-server communication over bidirectional -JSON-RPC message streams. Specifically, version `2025-03-26` of the spec -defines two transports: +The MCP is defined in terms of client-server communication over bidirectional JSON-RPC message streams. Specifically, version `2025-03-26` of the spec defines two transports: - **stdio**: communication with a subprocess over stdin/stdout. -- **streamable http**: communication over a relatively complicated series of - text/event-stream GET and HTTP POST requests. 
+- **streamable http**: communication over a relatively complicated series of text/event-stream GET and HTTP POST requests. -Additionally, version `2024-11-05` of the spec defined a simpler (yet stateful) -HTTP transport: +Additionally, version `2024-11-05` of the spec defined a simpler (yet stateful) HTTP transport: -- **sse**: client issues a hanging GET request and receives messages via - `text/event-stream`, and sends messages via POST to a session endpoint. +- **sse**: client issues a hanging GET request and receives messages via `text/event-stream`, and sends messages via POST to a session endpoint. -Furthermore, the spec -[states](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#custom-transports) -that it must be possible for users to define their own custom transports. +Furthermore, the spec [states](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#custom-transports) that it must be possible for users to define their own custom transports. -Given the diversity of the transport implementations, they can be challenging -to abstract. However, since JSON-RPC requires a bidirectional stream, we can -use this to model the MCP transport abstraction: +Given the diversity of the transport implementations, they can be challenging to abstract. However, since JSON-RPC requires a bidirectional stream, we can use this to model the MCP transport abstraction: ```go // A Transport is used to create a bidirectional connection between MCP client @@ -130,42 +72,23 @@ type Stream interface { } ``` -Methods accept a Go `Context` and return an `error`, as is idiomatic for APIs -that do I/O. +Methods accept a Go `Context` and return an `error`, as is idiomatic for APIs that do I/O. -A `Transport` is something that connects a logical JSON-RPC stream, and nothing -more. Streams must be closeable in order to implement client and server -shutdown, and therefore conform to the `io.Closer` interface. 
+A `Transport` is something that connects a logical JSON-RPC stream, and nothing more. Streams must be closeable in order to implement client and server shutdown, and therefore conform to the `io.Closer` interface. -Other SDKs define higher-level transports, with, for example, methods to send a -notification or make a call. Those are jsonrpc2 operations on top of the -logical stream, and the lower-level interface is easier to implement in most -cases, which means it is easier to implement custom transports. +Other SDKs define higher-level transports, with, for example, methods to send a notification or make a call. Those are jsonrpc2 operations on top of the logical stream, and the lower-level interface is easier to implement in most cases, which means it is easier to implement custom transports. -For our prototype, we've used an internal `jsonrpc2` package based on the Go -language server `gopls`, which we propose to fork for the MCP SDK. It already -handles concerns like client/server connection, request lifecycle, -cancellation, and shutdown. +For our prototype, we've used an internal `jsonrpc2` package based on the Go language server `gopls`, which we propose to fork for the MCP SDK. It already handles concerns like client/server connection, request lifecycle, cancellation, and shutdown. -**Differences from mcp-go**: The Go team has a battle-tested JSON-RPC -implementation that we use for gopls, our Go LSP server. We are using the new -version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 -features, including cancellation. +**Differences from mcp-go**: The Go team has a battle-tested JSON-RPC implementation that we use for gopls, our Go LSP server. We are using the new version of this library as part of our MCP SDK. It handles all JSON-RPC 2.0 features, including cancellation. -The `Transport` interface here is lower-level than that of mcp-go, but serves a -similar purpose. We believe the lower-level interface is easier to implement. 
+The `Transport` interface here is lower-level than that of mcp-go, but serves a similar purpose. We believe the lower-level interface is easier to implement. #### stdio transports -In the MCP Spec, the **stdio** transport uses newline-delimited JSON to -communicate over stdin/stdout. It's possible to model both client side and -server side of this communication with a shared type that communicates over an -`io.ReadWriteCloser`. However, for the purposes of future-proofing, we should -use a different types for client and server stdio transport. +In the MCP Spec, the **stdio** transport uses newline-delimited JSON to communicate over stdin/stdout. It's possible to model both client side and server side of this communication with a shared type that communicates over an `io.ReadWriteCloser`. However, for the purposes of future-proofing, we should use different types for the client and server stdio transports. -The `CommandTransport` is the client side of the stdio transport, and -connects by starting a command and binding its jsonrpc2 stream to its -stdin/stdout. +The `CommandTransport` is the client side of the stdio transport, and connects by starting a command and binding its jsonrpc2 stream to its stdin/stdout. ```go // A CommandTransport is a [Transport] that runs a command and communicates @@ -180,8 +103,7 @@ func NewCommandTransport(cmd *exec.Command) *CommandTransport func (*CommandTransport) Connect(ctx context.Context) (Stream, error) { ``` -The `StdIOTransport` is the server side of the stdio transport, and connects by -binding to `os.Stdin` and `os.Stdout`. +The `StdIOTransport` is the server side of the stdio transport, and connects by binding to `os.Stdin` and `os.Stdout`. ```go // A StdIOTransport is a [Transport] that communicates using newline-delimited @@ -195,18 +117,9 @@ func (t *StdIOTransport) Connect(context.Context) (Stream, error) #### HTTP transports -The HTTP transport APIs are even more asymmetrical.
Since connections are initiated -via HTTP requests, the client developer will create a transport, but -the server developer will typically install an HTTP handler. Internally, the -HTTP handler will create a logical transport for each new client connection. +The HTTP transport APIs are even more asymmetrical. Since connections are initiated via HTTP requests, the client developer will create a transport, but the server developer will typically install an HTTP handler. Internally, the HTTP handler will create a logical transport for each new client connection. -Importantly, since they serve many connections, the HTTP handlers must accept a -callback to get an MCP server for each new session. As described below, MCP -servers can optionally connect to multiple clients. This allows customization -of per-session servers: if the MCP server is stateless, the user can return the -same MCP server for each connection. On the other hand, if any per-session -customization is required, it is possible by returning a different `Server` -instance for each connection. +Importantly, since they serve many connections, the HTTP handlers must accept a callback to get an MCP server for each new session. As described below, MCP servers can optionally connect to multiple clients. This allows customization of per-session servers: if the MCP server is stateless, the user can return the same MCP server for each connection. On the other hand, if any per-session customization is required, it is possible by returning a different `Server` instance for each connection. ```go // SSEHTTPHandler is an http.Handler that serves SSE-based MCP sessions as defined by @@ -226,15 +139,9 @@ func (*SSEHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) func (*SSEHTTPHandler) Close() error ``` -Notably absent are options to hook into low-level request handling for the purposes -of authentication or context injection. These concerns are better handled using -standard HTTP middleware patterns. 
For middleware at the level of the MCP protocol, -see [Middleware](#Middleware) below. +Notably absent are options to hook into low-level request handling for the purposes of authentication or context injection. These concerns are instead handled using standard HTTP middleware patterns. For middleware at the level of the MCP protocol, see [Middleware](#Middleware) below. -By default, the SSE handler creates messages endpoints with the -`?sessionId=...` query parameter. Users that want more control over the -management of sessions and session endpoints may write their own handler, and -create `SSEServerTransport` instances themselves for incoming GET requests. +By default, the SSE handler creates messages endpoints with the `?sessionId=...` query parameter. Users that want more control over the management of sessions and session endpoints may write their own handler, and create `SSEServerTransport` instances themselves for incoming GET requests. ```go // A SSEServerTransport is a logical SSE session created through a hanging GET @@ -307,30 +214,13 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport { func (*StreamableClientTransport) Connect(context.Context) (Stream, error) ``` -**Differences from mcp-go**: In mcp-go, server authors create an `MCPServer`, -populate it with tools, resources and so on, and then wrap it in an `SSEServer` -or `StdioServer`. Users can manage their own sessions with `RegisterSession` -and `UnregisterSession`. Rather than use a server constructor to get a distinct -server for each connection, there is a concept of a "session tool" that -overlays tools for a specific session. - -We find the similarity in names among the three server types to be confusing, -and we could not discover any uses of the session methods in the open-source -ecosystem. 
Furthermore, we believe that a server factory (`getServer`) provides -equivalent functionality as the per-session logic of mcp-go, with a smaller API -surface and fewer overlapping concepts. +**Differences from mcp-go**: In mcp-go, server authors create an `MCPServer`, populate it with tools, resources and so on, and then wrap it in an `SSEServer` or `StdioServer`. Users can manage their own sessions with `RegisterSession` and `UnregisterSession`. Rather than use a server constructor to get a distinct server for each connection, there is a concept of a "session tool" that overlays tools for a specific session. -Additionally, individual handlers and transports here have a minimal API, and -do not expose internal details. Customization of things like handlers or -session management is intended to be implemented with middleware and/or -compositional patterns. +Here, we tried to differentiate the concept of a `Server`, `HTTPHandler`, and `Transport`, and provide per-session customization through either the `getServer` constructor or middleware. Additionally, individual handlers and transports here have a minimal API, and do not expose internal details. (Open question: are we oversimplifying?) #### Other transports -We also provide a couple of transport implementations for special scenarios. -An InMemoryTransport can be used when the client and server reside in the same -process. A LoggingTransport is a middleware layer that logs RPC logs to a desired -location, specified as an io.Writer. +We also provide a couple of transport implementations for special scenarios. An InMemoryTransport can be used when the client and server reside in the same process. A LoggingTransport is a middleware layer that logs RPC logs to a desired location, specified as an io.Writer. 
```go // An InMemoryTransport is a [Transport] that communicates over an in-memory @@ -349,21 +239,13 @@ func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport ### Protocol types -Types needed for the protocol are generated from the -[JSON schema of the MCP spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). +Types needed for the protocol are generated from the [JSON schema of the MCP spec](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json). -These types will be included in the `mcp` package, but will be unexported -unless they are needed for the user-facing API. Notably, JSON-RPC request types -are elided, since they are handled by the `jsonrpc2` package and should not be -observed by the user. +These types will be included in the `mcp` package, but will be unexported unless they are needed for the user-facing API. Notably, JSON-RPC request types are elided, since they are handled by the `jsonrpc2` package and should not be observed by the user. -For user-provided data, we use `json.RawMessage`, so that -marshalling/unmarshalling can be delegated to the business logic of the client -or server. +For user-provided data, we use `json.RawMessage`, so that marshalling/unmarshalling can be delegated to the business logic of the client or server. -For union types, which can't be represented in Go (specifically `Content` and -`ResourceContents`), we prefer distinguished unions: struct types with fields -corresponding to the union of all properties for union elements. +For union types, which can't be represented in Go (specifically `Content` and `ResourceContents`), we prefer distinguished unions: struct types with fields corresponding to the union of all properties for union elements. For brevity, only a few examples are shown here: @@ -395,39 +277,19 @@ func NewTextContent(text string) *Content // etc. 
``` -**Differences from mcp-go**: these types are largely similar, but our type -generator flattens types rather than using struct embedding. +**Differences from mcp-go**: these types are largely similar, but our type generator flattens types rather than using struct embedding. ### Clients and Servers -Generally speaking, the SDK is used by creating a `Client` or `Server` -instance, adding features to it, and connecting it to a peer. - -However, the SDK must make a non-obvious choice in these APIs: are clients 1:1 -with their logical connections? What about servers? Both clients and servers -are stateful: users may add or remove roots from clients, and tools, prompts, -and resources from servers. Additionally, handlers for these features may -themselves be stateful, for example if a tool handler caches state from earlier -requests in the session. - -We believe that in the common case, any change to a client or server, such as -adding a tool, is intended for all its peers. It is therefore more useful to -allow multiple connections from a client, and to a server. This is similar to -the `net/http` packages, in which an `http.Client` and `http.Server` each may -handle multiple unrelated connections. When users add features to a client or -server, all connected peers are notified of the change. - -Supporting multiple connections to servers (and from clients) still allows for -stateful components, as it is up to the user to decide whether or not to create -distinct servers/clients for each connection. For example, if the user wants to -create a distinct server for each new connection, they can do so in the -`getServer` factory passed to transport handlers. - -Following the terminology of the -[spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), -we call the logical connection between a client and server a "session." 
There -must necessarily be a `ClientSession` and a `ServerSession`, corresponding to -the APIs available from the client and server perspective, respectively. +Generally speaking, the SDK is used by creating a `Client` or `Server` instance, adding features to it, and connecting it to a peer. + +However, the SDK must make a non-obvious choice in these APIs: are clients 1:1 with their logical connections? What about servers? Both clients and servers are stateful: users may add or remove roots from clients, and tools, prompts, and resources from servers. Additionally, handlers for these features may themselves be stateful, for example if a tool handler caches state from earlier requests in the session. + +We believe that in the common case, any change to a client or server, such as adding a tool, is intended for all its peers. It is therefore more useful to allow multiple connections from a client, and to a server. This is similar to the `net/http` packages, in which an `http.Client` and `http.Server` each may handle multiple unrelated connections. When users add features to a client or server, all connected peers are notified of the change. + +Supporting multiple connections to servers (and from clients) still allows for stateful components, as it is up to the user to decide whether or not to create distinct servers/clients for each connection. For example, if the user wants to create a distinct server for each new connection, they can do so in the `getServer` factory passed to transport handlers. + +Following the terminology of the [spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management), we call the logical connection between a client and server a "session." There must necessarily be a `ClientSession` and a `ServerSession`, corresponding to the APIs available from the client and server perspective, respectively. 
``` Client Server @@ -435,8 +297,7 @@ Client Server ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession ``` -Sessions are created from either `Client` or `Server` using the `Connect` -method. +Sessions are created from either `Client` or `Server` using the `Connect` method. ```go type Client struct { /* ... */ } @@ -497,68 +358,33 @@ session, err := server.Connect(ctx, transport) return session.Wait() ``` -For convenience, we provide `Server.Run` to handle the common case of running a -session until the client disconnects: +For convenience, we provide `Server.Run` to handle the common case of running a session until the client disconnects: ```go func (*Server) Run(context.Context, Transport) ``` -**Differences from mcp-go**: the Server APIs are similar to mcp-go, though the -association between servers and transports is different. In mcp-go, a single -server is bound to what we would call an `SSEHTTPHandler`, and reused for all -sessions. Per-session behavior is implemented though a 'session tool' overlay. -As discussed above, the transport abstraction here is differentiated from HTTP -serving, and the `Server.Connect` method provides a consistent API for binding -to an arbitrary transport. Servers here do not have methods for sending -notifications or calls, because they are logically distinct from the -`ServerSession`. In mcp-go, servers are `n:1`, but there is no abstraction of a -server session: sessions are addressed in Server APIs through their -`sessionID`: `SendNotificationToAllClients`, `SendNotificationToClient`, -`SendNotificationToSpecificClient`. - -The client API here is different, since clients and client sessions are -conceptually distinct. The `ClientSession` is closer to mcp-go's notion of -Client. - -For both clients and servers, mcp-go uses variadic options to customize -behavior, whereas an options struct is used here. 
We felt that in this case, an -options struct would be more readable, and result in simpler package -documentation. +**Differences from mcp-go**: the Server APIs are similar to mcp-go, though the association between servers and transports is different. In mcp-go, a single server is bound to what we would call an `SSEHTTPHandler`, and reused for all sessions. Per-session behavior is implemented through a 'session tool' overlay. As discussed above, the transport abstraction here is differentiated from HTTP serving, and the `Server.Connect` method provides a consistent API for binding to an arbitrary transport. Servers here do not have methods for sending notifications or calls, because they are logically distinct from the `ServerSession`. In mcp-go, servers are `n:1`, but there is no abstraction of a server session: sessions are addressed in Server APIs through their `sessionID`: `SendNotificationToAllClients`, `SendNotificationToClient`, `SendNotificationToSpecificClient`. + +The client API here is different, since clients and client sessions are conceptually distinct. The `ClientSession` is closer to mcp-go's notion of Client. + +For both clients and servers, mcp-go uses variadic options to customize behavior, whereas an options struct is used here. We felt that in this case, an options struct would be more readable, and result in simpler package documentation. ### Spec Methods -In our SDK, RPC methods that are defined in the specification take a context and -a params pointer as arguments, and return a result pointer and error: +In our SDK, RPC methods that are defined in the specification take a context and a params pointer as arguments, and return a result pointer and error: ```go func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsResult, error) ``` -Our SDK has a method for every RPC in the spec, and except for `CallTool`, -their signatures all share this form.
We do this, rather than providing more -convenient shortcut signatures, to maintain backward compatibility if the spec -makes backward-compatible changes such as adding a new property to the request -parameters -(as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), -for example). -To avoid boilerplate, we don't repeat this signature for RPCs defined in the -spec; readers may assume it when we mention a "spec method." - -`CallTool` is the only exception: for convenience, it takes the tool name and -arguments, with an options struct for additional request fields. See the -section on Tools below for details. - -Why do we use params instead of the full JSON-RPC request? As much as possible, -we endeavor to hide JSON-RPC details when they are not relevant to the business -logic of your client or server. In this case, the additional information in the -JSON-RPC request is just the request ID and method name; the request ID is -irrelevant, and the method name is implied by the name of the Go method -providing the API. - -We believe that any change to the spec that would require callers to pass a new a -parameter is not backward compatible. Therefore, it will always work to pass -`nil` for any `XXXParams` argument that isn't currently necessary. For example, it is okay to call `Ping` like so: +Our SDK has a method for every RPC in the spec, and except for `CallTool`, their signatures all share this form. We do this, rather than providing more convenient shortcut signatures, to maintain backward compatibility if the spec makes backward-compatible changes such as adding a new property to the request parameters (as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), for example). To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." 
+ +`CallTool` is the only exception: for convenience, it takes the tool name and arguments, with an options struct for additional request fields. See the section on Tools below for details. + +Why do we use params instead of the full JSON-RPC request? As much as possible, we endeavor to hide JSON-RPC details when they are not relevant to the business logic of your client or server. In this case, the additional information in the JSON-RPC request is just the request ID and method name; the request ID is irrelevant, and the method name is implied by the name of the Go method providing the API. + +We believe that any change to the spec that would require callers to pass a new parameter is not backward compatible. Therefore, it will always work to pass `nil` for any `XXXParams` argument that isn't currently necessary. For example, it is okay to call `Ping` like so: ```go err := session.Ping(ctx, nil) @@ -566,9 +392,7 @@ err := session.Ping(ctx, nil) #### Iterator Methods -For convenience, iterator methods handle pagination for the `List` spec methods -automatically, traversing all pages. If Params are supplied, iteration begins -from the provided cursor (if present). +For convenience, iterator methods handle pagination for the `List` spec methods automatically, traversing all pages. If Params are supplied, iteration begins from the provided cursor (if present). ```go func (*ClientSession) Tools(context.Context, *ListToolsParams) iter.Seq2[Tool, error] @@ -582,8 +406,7 @@ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesP ### Middleware -We provide a mechanism to add MCP-level middleware on the server side, which runs after the -request has been parsed but before any normal handling. +We provide a mechanism to add MCP-level middleware on the server side, which runs after the request has been parsed but before any normal handling. ```go // A Dispatcher dispatches an MCP message to the appropriate handler.
@@ -610,19 +433,13 @@ func withLogging(h mcp.Dispatcher) mcp.Dispatcher { server.AddDispatchers(withLogging) ``` -**Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. -Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and -a type for the hook function. These are rarely used. The most common is -`OnError`, which occurs fewer than ten times in open-source code. +**Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. These are rarely used. The most common is `OnError`, which occurs fewer than ten times in open-source code. ### Errors -With the exception of tool handler errors, protocol errors are handled -transparently as Go errors: errors in server-side feature handlers are -propagated as errors from calls from the `ClientSession`, and vice-versa. +With the exception of tool handler errors, protocol errors are handled transparently as Go errors: errors in server-side feature handlers are propagated as errors from calls from the `ClientSession`, and vice-versa. -Protocol errors wrap a `JSONRPCError` type which exposes its underlying error -code. +Protocol errors wrap a `JSONRPCError` type which exposes its underlying error code. ```go type JSONRPCError struct { @@ -632,18 +449,13 @@ type JSONRPCError struct { } ``` -As described by the -[spec](https://modelcontextprotocol.io/specification/2025-03-26/server/tools#error-handling), -tool execution errors are reported in tool results. +As described by the [spec](https://modelcontextprotocol.io/specification/2025-03-26/server/tools#error-handling), tool execution errors are reported in tool results. -**Differences from mcp-go**: the `JSONRPCError` type here does not include ID -and Method, which can be inferred from the caller. Otherwise, this behavior is -similar. 
+**Differences from mcp-go**: the `JSONRPCError` type here does not include ID and Method, which can be inferred from the caller. Otherwise, this behavior is similar. ### Cancellation -Cancellation is implemented transparently using context cancellation. The user -can cancel an operation by cancelling the associated context: +Cancellation is implemented transparently using context cancellation. The user can cancel an operation by cancelling the associated context: ```go ctx, cancel := context.WithCancel(ctx) @@ -651,9 +463,7 @@ go session.CallTool(ctx, "slow", map[string]any{}, nil) cancel() ``` -When this client call is cancelled, a `"notifications/cancelled"` notification -is sent to the server. However, the client call returns immediately with -`ctx.Err()`: it does not wait for the result from the server. +When this client call is cancelled, a `"notifications/cancelled"` notification is sent to the server. However, the client call returns immediately with `ctx.Err()`: it does not wait for the result from the server. The server observes a client cancellation as a cancelled context. @@ -668,9 +478,7 @@ type XXXParams struct { // where XXX is each type of call } ``` -Handlers can notify their peer about progress by calling the `NotifyProgress` -method. The notification is only sent if the peer requested it by providing -a progress token. +Handlers can notify their peer about progress by calling the `NotifyProgress` method. The notification is only sent if the peer requested it by providing a progress token. ```go func (*ClientSession) NotifyProgress(context.Context, *ProgressNotification) @@ -679,19 +487,14 @@ func (*ServerSession) NotifyProgress(context.Context, *ProgressNotification) ### Ping / KeepAlive -Both `ClientSession` and `ServerSession` expose a `Ping` method to call "ping" -on their peer. +Both `ClientSession` and `ServerSession` expose a `Ping` method to call "ping" on their peer. 
```go func (c *ClientSession) Ping(ctx context.Context, *PingParams) error func (c *ServerSession) Ping(ctx context.Context, *PingParams) error ``` -Additionally, client and server sessions can be configured with automatic -keepalive behavior. If the `KeepAlive` option is set to a non-zero duration, -it defines an interval for regular "ping" requests. If the peer fails to -respond to pings originating from the keepalive check, the session is -automatically closed. +Additionally, client and server sessions can be configured with automatic keepalive behavior. If the `KeepAlive` option is set to a non-zero duration, it defines an interval for regular "ping" requests. If the peer fails to respond to pings originating from the keepalive check, the session is automatically closed. ```go type ClientOptions struct { @@ -705,17 +508,13 @@ type ServerOptions struct { } ``` -**Differences from mcp-go**: in mcp-go the `Ping` method is only provided for -client, not server, and the keepalive option is only provided for SSE servers -(as a variadic option). +**Differences from mcp-go**: in mcp-go the `Ping` method is only provided for client, not server, and the keepalive option is only provided for SSE servers (as a variadic option). ## Client Features ### Roots -Clients support the MCP Roots feature, including roots-changed notifications. -Roots can be added and removed from a `Client` with `AddRoots` and -`RemoveRoots`: +Clients support the MCP Roots feature, including roots-changed notifications. Roots can be added and removed from a `Client` with `AddRoots` and `RemoveRoots`: ```go // AddRoots adds the given roots to the client, @@ -729,10 +528,7 @@ func (*Client) AddRoots(roots ...*Root) func (*Client) RemoveRoots(uris ...string) ``` -Server sessions can call the spec method `ListRoots` to get the roots. 
If a -server installs a `RootsChangedHandler`, it will be called when the client -sends a roots-changed notification, which happens whenever the list of roots -changes after a connection has been established. +Server sessions can call the spec method `ListRoots` to get the roots. If a server installs a `RootsChangedHandler`, it will be called when the client sends a roots-changed notification, which happens whenever the list of roots changes after a connection has been established. ```go type ServerOptions { @@ -742,9 +538,7 @@ type ServerOptions { } ``` -The `Roots` method provides a -[cached](https://modelcontextprotocol.io/specification/2025-03-26/client/roots#implementation-guidelines) -iterator of the root set, invalidated when roots change. +The `Roots` method provides a [cached](https://modelcontextprotocol.io/specification/2025-03-26/client/roots#implementation-guidelines) iterator of the root set, invalidated when roots change. ```go func (*ServerSession) Roots(context.Context) (iter.Seq[*Root, error]) @@ -752,8 +546,7 @@ func (*ServerSession) Roots(context.Context) (iter.Seq[*Root, error]) ### Sampling -Clients that support sampling are created with a `CreateMessageHandler` option -for handling server calls. To perform sampling, a server session calls the spec method `CreateMessage`. +Clients that support sampling are created with a `CreateMessageHandler` option for handling server calls. To perform sampling, a server session calls the spec method `CreateMessage`. ```go type ClientOptions struct { @@ -766,8 +559,7 @@ type ClientOptions struct { ### Tools -A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` -is a tool bound to a tool handler. +A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` is a tool bound to a tool handler. 
```go type Tool struct { @@ -799,32 +591,14 @@ Remove them by name with `RemoveTools`: server.RemoveTools("add", "subtract") ``` -A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), -provides a way to validate the tool's input. One of the challenges in defining -tools is the need to associate them with a Go function, yet support the -arbitrary complexity of JSON Schema. To achieve this, we have seen two primary -approaches: +A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), provides a way to validate the tool's input. One of the challenges in defining tools is the need to associate them with a Go function, yet support the arbitrary complexity of JSON Schema. To achieve this, we have seen two primary approaches: -1. Use reflection to generate the tool's input schema from a Go type (à la - `metoro-io/mcp-golang`) +1. Use reflection to generate the tool's input schema from a Go type (à la `metoro-io/mcp-golang`) 2. Explicitly build the input schema (à la `mark3labs/mcp-go`). -Both of these have their advantages and disadvantages. Reflection is nice, -because it allows you to bind directly to a Go API, and means that the JSON -schema of your API is compatible with your Go types by construction. It also -means that concerns like parsing and validation can be handled automatically. -However, it can become cumbersome to express the full breadth of JSON schema -using Go types or struct tags, and sometimes you want to express things that -aren’t naturally modeled by Go types, like unions. Explicit schemas are simple -and readable, and give the caller full control over their tool definition, but -involve significant boilerplate. - -We have found that a hybrid model works well, where the _initial_ schema is -derived using reflection, but any customization on top of that schema is -applied using variadic options. 
We achieve this using a `NewTool` helper, which -generates the schema from the input type, and wraps the handler to provide -parsing and validation. The schema (and potentially other features) can be -customized using ToolOptions. +Both of these have their advantages and disadvantages. Reflection is nice, because it allows you to bind directly to a Go API, and means that the JSON schema of your API is compatible with your Go types by construction. It also means that concerns like parsing and validation can be handled automatically. However, it can become cumbersome to express the full breadth of JSON schema using Go types or struct tags, and sometimes you want to express things that aren’t naturally modeled by Go types, like unions. Explicit schemas are simple and readable, and give the caller full control over their tool definition, but involve significant boilerplate. + +We have found that a hybrid model works well, where the _initial_ schema is derived using reflection, but any customization on top of that schema is applied using variadic options. We achieve this using a `NewTool` helper, which generates the schema from the input type, and wraps the handler to provide parsing and validation. The schema (and potentially other features) can be customized using ToolOptions. ```go // NewTool creates a Tool using reflection on the given handler. @@ -833,11 +607,7 @@ func NewTool[TInput any](name, description string, handler func(context.Context, type ToolOption interface { /* ... */ } ``` -`NewTool` determines the input schema for a Tool from the struct used -in the handler. Each struct field that would be marshaled by `encoding/json.Marshal` -becomes a property of the schema. The property is required unless -the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). -For example, given this struct: +`NewTool` determines the input schema for a Tool from the struct used in the handler. 
Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: ```go struct { @@ -850,9 +620,7 @@ struct { "name" and "Choices" are required, while "count" is optional. -As of this writing, the only `ToolOption` is `Input`, which allows customizing the -input schema of the tool using schema options. These schema options are -recursive, in the sense that they may also be applied to properties. +As of this writing, the only `ToolOption` is `Input`, which allows customizing the input schema of the tool using schema options. These schema options are recursive, in the sense that they may also be applied to properties. ```go func Input(...SchemaOption) ToolOption @@ -869,10 +637,7 @@ NewTool(name, description, handler, Input(Property("count", Description("size of the inventory")))) ``` -The most recent JSON Schema spec defines over 40 keywords. Providing them all -as options would bloat the API despite the fact that most would be very rarely -used. For less common keywords, use the `Schema` option to set the schema -explicitly: +The most recent JSON Schema spec defines over 40 keywords. Providing them all as options would bloat the API despite the fact that most would be very rarely used. For less common keywords, use the `Schema` option to set the schema explicitly: ```go NewTool(name, description, handler, @@ -881,16 +646,11 @@ NewTool(name, description, handler, Schemas are validated on the server before the tool handler is called. -Since all the fields of the Tool struct are exported, a Tool can also be created -directly with assignment or a struct literal. +Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. 
-Client sessions can call the spec method `ListTools` or an iterator method `Tools` -to list the available tools. +Client sessions can call the spec method `ListTools` or an iterator method `Tools` to list the available tools. -As mentioned above, the client session method `CallTool` has a non-standard -signature, so that `CallTool` can handle the marshalling of tool arguments: the -type of `CallToolParams.Arguments` is `json.RawMessage`, to delegate -unmarshalling to the tool handler. +As mentioned above, the client session method `CallTool` has a non-standard signature, so that `CallTool` can handle the marshalling of tool arguments: the type of `CallToolParams.Arguments` is `json.RawMessage`, to delegate unmarshalling to the tool handler. ```go func (c *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) @@ -900,36 +660,17 @@ type CallToolOptions struct { } ``` -**Differences from mcp-go**: using variadic options to configure tools was -significantly inspired by mcp-go. However, the distinction between `ToolOption` -and `SchemaOption` allows for recursive application of schema options. -For example, that limitation is visible in [this -code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), -which must resort to untyped maps to express a nested schema. - -Additionally, the `NewTool` helper provides a means for building a tool from a -Go function using reflection, that automatically handles parsing and validation -of inputs. - -We provide a full JSON Schema implementation for validating tool input schemas -against incoming arguments. The `jsonschema.Schema` type provides exported -features for all keywords in the JSON Schema draft2020-12 spec. Tool definers -can use it to construct any schema they want, so there is no need to provide -options for all of them. 
When combined with schema inference from input -structs, we found that we needed only three options to cover the common cases, -instead of mcp-go's 23. For example, we will provide `Enum`, which occurs 125 -times in open source code, but not MinItems, MinLength or MinProperties, which -each occur only once (and in an SDK that wraps mcp-go). - -For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, -`AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. -(Similarly for Delete/Remove). +**Differences from mcp-go**: using variadic options to configure tools was significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options, which mcp-go does not support. For example, that limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema. + +Additionally, the `NewTool` helper provides a means for building a tool from a Go function using reflection, that automatically handles parsing and validation of inputs. + +We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want, so there is no need to provide options for all of them. When combined with schema inference from input structs, we found that we needed only three options to cover the common cases, instead of mcp-go's 23. For example, we will provide `Enum`, which occurs 125 times in open source code, but not MinItems, MinLength or MinProperties, which each occur only once (and in an SDK that wraps mcp-go). + +For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. (Similarly for Delete/Remove).
### Prompts -Use `NewPrompt` to create a prompt. -As with tools, prompt argument schemas can be inferred from a struct, or obtained -from options. +Use `NewPrompt` to create a prompt. As with tools, prompt argument schemas can be inferred from a struct, or obtained from options. ```go func NewPrompt[TReq any](name, description string, @@ -954,26 +695,19 @@ server.AddPrompts( server.RemovePrompts("code_review") ``` -Client sessions can call the spec method `ListPrompts` or the iterator method `Prompts` -to list the available prompts, and the spec method `GetPrompt` to get one. +Client sessions can call the spec method `ListPrompts` or the iterator method `Prompts` to list the available prompts, and the spec method `GetPrompt` to get one. -**Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt -handler to a Go function using reflection to derive its arguments. We provide -`RemovePrompts` to remove prompts from the server. +**Differences from mcp-go**: We provide a `NewPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. We provide `RemovePrompts` to remove prompts from the server. ### Resources and resource templates -In our design, each resource and resource template is associated with a function that reads it, -with this signature: +In our design, each resource and resource template is associated with a function that reads it, with this signature: ```go type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) ``` -The arguments include the `ServerSession` so the handler can observe the client's roots. -The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` -or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the -resource. +The arguments include the `ServerSession` so the handler can observe the client's roots. 
The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the resource. The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: @@ -989,9 +723,7 @@ type ServerResourceTemplate struct { } ``` -To add a resource or resource template to a server, users call the `AddResources` and -`AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. -We also provide methods to remove them. +To add a resource or resource template to a server, users call the `AddResources` and `AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. We also provide methods to remove them. ```go func (*Server) AddResources(...*ServerResource) @@ -1001,8 +733,7 @@ func (s *Server) RemoveResources(uris ...string) func (s *Server) RemoveResourceTemplates(uriTemplates ...string) ``` -The `ReadResource` method finds a resource or resource template matching the argument URI and calls -its assocated handler. +The `ReadResource` method finds a resource or resource template matching the argument URI and calls its associated handler. To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: @@ -1022,13 +753,9 @@ s.AddResources(&mcp.ServerResource{ Handler: s.FileReadResourceHandler("/public")}) ``` -Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, -and the corresponding iterator methods `Resources` and `ResourceTemplates`. +Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. -**Differences from mcp-go**: for symmetry with tools and prompts, we use -`AddResources` rather than `AddResource`.
Additionally, the `ResourceHandler` -returns a `ReadResourceResult`, rather than just its content, for compatibility -with future evolution of the spec. +**Differences from mcp-go**: for symmetry with tools and prompts, we use `AddResources` rather than `AddResource`. Additionally, the `ResourceHandler` returns a `ReadResourceResult`, rather than just its content, for compatibility with future evolution of the spec. #### Subscriptions @@ -1039,14 +766,9 @@ func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error ``` -The server does not implement resource subscriptions. It passes along -subscription requests to the user, and supplies a method to notify clients of -changes. It tracks which sessions have subscribed to which resources so the -user doesn't have to. +The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to. -If a server author wants to support resource subscriptions, they must provide handlers -to be called when clients subscribe and unsubscribe. It is an error to provide only -one of these handlers. +If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. ```go type ServerOptions struct { @@ -1068,10 +790,7 @@ The server routes these notifications to the server sessions that subscribed to ### ListChanged notifications -When a list of tools, prompts or resources changes as the result of an AddXXX -or RemoveXXX call, the server informs all its connected clients by sending the -corresponding type of notification. 
-A client will receive these notifications if it was created with the corresponding option: +When a list of tools, prompts or resources changes as the result of an AddXXX or RemoveXXX call, the server informs all its connected clients by sending the corresponding type of notification. A client will receive these notifications if it was created with the corresponding option: ```go type ClientOptions struct { @@ -1083,18 +802,13 @@ type ClientOptions struct { } ``` -**Differences from mcp-go**: mcp-go instead provides a general `OnNotification` -handler. For type-safety, and to hide JSON RPC details, we provide -feature-specific handlers here. +**Differences from mcp-go**: mcp-go instead provides a general `OnNotification` handler. For type-safety, and to hide JSON RPC details, we provide feature-specific handlers here. ### Completion -Clients call the spec method `Complete` to request completions. -Servers automatically handle these requests based on their collections of -prompts and resources. +Clients call the spec method `Complete` to request completions. Servers automatically handle these requests based on their collections of prompts and resources. -**Differences from mcp-go**: the client API is similar. mcp-go has not yet -defined its server-side behavior. +**Differences from mcp-go**: the client API is similar. mcp-go has not yet defined its server-side behavior. ### Logging @@ -1111,22 +825,13 @@ type ServerOptions { } ``` -Server sessions have a field `Logger` holding a `slog.Logger` that writes to the client session. -A call to a log method like `Info` is translated to a `LoggingMessageNotification` as -follows: +Server sessions have a field `Logger` holding a `slog.Logger` that writes to the client session. 
A call to a log method like `Info` is translated to a `LoggingMessageNotification` as follows: -- The attributes and the message populate the "data" property with the - output of a `slog.JSONHandler`: The result is always a JSON object, with the - key "msg" for the message. +- The attributes and the message populate the "data" property with the output of a `slog.JSONHandler`: The result is always a JSON object, with the key "msg" for the message. - If the `LoggerName` server option is set, it populates the "logger" property. -- The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the - corresponding levels in the MCP spec. The other spec levels map - to integers between the slog levels. For example, "notice" is level 2 because - it is between "warning" (slog value 4) and "info" (slog value 0). - The `mcp` package defines consts for these levels. To log at the "notice" - level, a handler would call `session.Logger.Log(ctx, mcp.LevelNotice, "message")`. +- The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the corresponding levels in the MCP spec. The other spec levels map to integers between the slog levels. For example, "notice" is level 2 because it is between "warning" (slog value 4) and "info" (slog value 0). The `mcp` package defines consts for these levels. To log at the "notice" level, a handler would call `session.Logger.Log(ctx, mcp.LevelNotice, "message")`. A client that wishes to receive log messages must provide a handler: @@ -1139,15 +844,9 @@ type ClientOptions struct { ### Pagination -Servers initiate pagination for `ListTools`, `ListPrompts`, `ListResources`, -and `ListResourceTemplates`, dictating the page size and providing a -`NextCursor` field in the Result if more pages exist. The SDK implements keyset -pagination, using the unique ID of the feature as the key for a stable sort order and encoding -the cursor as an opaque string. 
+Servers initiate pagination for `ListTools`, `ListPrompts`, `ListResources`, and `ListResourceTemplates`, dictating the page size and providing a `NextCursor` field in the Result if more pages exist. The SDK implements keyset pagination, using the unique ID of the feature as the key for a stable sort order and encoding the cursor as an opaque string. -For server implementations, the page size for the list operation may be -configured via the `ServerOptions.PageSize` field. PageSize must be a -non-negative integer. If zero, a sensible default is used. +For server implementations, the page size for the list operation may be configured via the `ServerOptions.PageSize` field. PageSize must be a non-negative integer. If zero, a sensible default is used. ```go type ServerOptions { @@ -1156,14 +855,8 @@ type ServerOptions { } ``` -Client requests for List methods include an optional Cursor field for -pagination. Server responses for List methods include a `NextCursor` field if -more pages exist. +Client requests for List methods include an optional Cursor field for pagination. Server responses for List methods include a `NextCursor` field if more pages exist. -In addition to the `List` methods, the SDK provides an iterator method for each -list operation. This simplifies pagination for clients by automatically handling -the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. +In addition to the `List` methods, the SDK provides an iterator method for each list operation. This simplifies pagination for clients by automatically handling the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. -**Differences with mcp-go**: the PageSize configuration is set with a -configuration field rather than a variadic option. Additionally, this design -proposes pagination by default, as this is likely desirable for most servers. 
+**Differences with mcp-go**: the PageSize configuration is set with a configuration field rather than a variadic option. Additionally, this design proposes pagination by default, as this is likely desirable for most servers From c6b2a9c0a5dfc90fe9ca1aecd86e470e064529b4 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 18:07:48 -0400 Subject: [PATCH 085/196] internal/typesinternal: add TypeNameFor helper This function eliminates most of the places where we had reinvented this wheel: type HasTypeName interface { Obj() *TypeName } Updates golang/go#66890 Updates golang/go#71886 Change-Id: I70d3c8141b20efb1ba0e02fc846d8b38b8f3a23c Reviewed-on: https://go-review.googlesource.com/c/tools/+/672975 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Commit-Queue: Alan Donovan --- go/analysis/passes/composite/composite.go | 3 ++- go/analysis/passes/stringintconv/string.go | 12 +++++----- go/types/objectpath/objectpath.go | 2 +- gopls/internal/golang/extract.go | 27 +++++++++------------ gopls/internal/protocol/command/gen/gen.go | 8 +++---- internal/refactor/inline/inline.go | 9 +++---- internal/typesinternal/types.go | 28 ++++++++++++++++++++++ refactor/rename/check.go | 2 +- refactor/rename/spec.go | 3 ++- 9 files changed, 57 insertions(+), 37 deletions(-) diff --git a/go/analysis/passes/composite/composite.go b/go/analysis/passes/composite/composite.go index 25c98a97bbc..ed2284e6306 100644 --- a/go/analysis/passes/composite/composite.go +++ b/go/analysis/passes/composite/composite.go @@ -153,7 +153,8 @@ func isLocalType(pass *analysis.Pass, typ types.Type) bool { return isLocalType(pass, x.Elem()) case interface{ Obj() *types.TypeName }: // *Named or *TypeParam (aliases were removed already) // names in package foo are local to foo_test too - return strings.TrimSuffix(x.Obj().Pkg().Path(), "_test") == strings.TrimSuffix(pass.Pkg.Path(), "_test") + return x.Obj().Pkg() != nil && + strings.TrimSuffix(x.Obj().Pkg().Path(), 
"_test") == strings.TrimSuffix(pass.Pkg.Path(), "_test") } return false } diff --git a/go/analysis/passes/stringintconv/string.go b/go/analysis/passes/stringintconv/string.go index a23721cd26f..7dbff1e4d8d 100644 --- a/go/analysis/passes/stringintconv/string.go +++ b/go/analysis/passes/stringintconv/string.go @@ -17,6 +17,7 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/typesinternal" ) //go:embed doc.go @@ -60,12 +61,11 @@ func describe(typ, inType types.Type, inName string) string { } func typeName(t types.Type) string { - type hasTypeName interface{ Obj() *types.TypeName } // Alias, Named, TypeParam - switch t := t.(type) { - case *types.Basic: - return t.Name() - case hasTypeName: - return t.Obj().Name() + if basic, ok := t.(*types.Basic); ok { + return basic.Name() // may be (e.g.) "untyped int", which has no TypeName + } + if tname := typesinternal.TypeNameFor(t); tname != nil { + return tname.Name() } return "" } diff --git a/go/types/objectpath/objectpath.go b/go/types/objectpath/objectpath.go index 16ed3c1780b..d3c2913bef3 100644 --- a/go/types/objectpath/objectpath.go +++ b/go/types/objectpath/objectpath.go @@ -603,7 +603,7 @@ func Object(pkg *types.Package, p Path) (types.Object, error) { type hasTypeParams interface { TypeParams() *types.TypeParamList } - // abstraction of *types.{Named,TypeParam} + // abstraction of *types.{Alias,Named,TypeParam} type hasObj interface { Obj() *types.TypeName } diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index f19285b8a3c..59916676fe9 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -1819,26 +1819,21 @@ var conventionalVarNames = map[objKey]string{ // // For special types, it uses known conventional names. 
func varNameForType(t types.Type) (string, bool) { - var typeName string - if tn, ok := t.(interface{ Obj() *types.TypeName }); ok { - obj := tn.Obj() - k := objKey{name: obj.Name()} - if obj.Pkg() != nil { - k.pkg = obj.Pkg().Name() - } - if name, ok := conventionalVarNames[k]; ok { - return name, true - } - typeName = obj.Name() - } else if b, ok := t.(*types.Basic); ok { - typeName = b.Name() + tname := typesinternal.TypeNameFor(t) + if tname == nil { + return "", false } - if typeName == "" { - return "", false + // Have Alias, Basic, Named, or TypeParam. + k := objKey{name: tname.Name()} + if tname.Pkg() != nil { + k.pkg = tname.Pkg().Name() + } + if name, ok := conventionalVarNames[k]; ok { + return name, true } - return AbbreviateVarName(typeName), true + return AbbreviateVarName(tname.Name()), true } // adjustReturnStatements adds "zero values" of the given types to each return diff --git a/gopls/internal/protocol/command/gen/gen.go b/gopls/internal/protocol/command/gen/gen.go index d4935020b38..779e6d83523 100644 --- a/gopls/internal/protocol/command/gen/gen.go +++ b/gopls/internal/protocol/command/gen/gen.go @@ -15,6 +15,7 @@ import ( "golang.org/x/tools/gopls/internal/protocol/command/commandmeta" "golang.org/x/tools/internal/imports" + "golang.org/x/tools/internal/typesinternal" ) const src = `// Copyright 2024 The Go Authors. All rights reserved. 
@@ -192,11 +193,8 @@ func Generate() ([]byte, error) { } func pkgPath(t types.Type) string { - type hasTypeName interface { // *Named or *Alias (or *TypeParam) - Obj() *types.TypeName - } - if t, ok := t.(hasTypeName); ok { - if pkg := t.Obj().Pkg(); pkg != nil { + if tname := typesinternal.TypeNameFor(t); tname != nil { + if pkg := tname.Pkg(); pkg != nil { return pkg.Path() } } diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index a17078cb96c..445f6b705c4 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -3516,12 +3516,9 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex typeName string obj *types.TypeName // nil for basic types ) - switch typ := typ.(type) { - case *types.Basic: - typeName = typ.Name() - case interface{ Obj() *types.TypeName }: // Named, Alias, TypeParam - obj = typ.Obj() - typeName = typ.Obj().Name() + if tname := typesinternal.TypeNameFor(typ); tname != nil { + obj = tname + typeName = tname.Name() } // Special case: check for universe "any". diff --git a/internal/typesinternal/types.go b/internal/typesinternal/types.go index cc244689ef8..1f292edb6d3 100644 --- a/internal/typesinternal/types.go +++ b/internal/typesinternal/types.go @@ -69,6 +69,34 @@ func NameRelativeTo(pkg *types.Package) types.Qualifier { } } +// TypeNameFor returns the type name symbol for the specified type, if +// it is a [*types.Alias], [*types.Named], [*types.TypeParam], or a +// [*types.Basic] representing a type. +// +// For all other types, and for Basic types representing a builtin, +// constant, or nil, it returns nil. Be careful not to convert the +// resulting nil pointer to a [types.Object]! +// +// If t is the type of a constant, it may be an "untyped" type, which +// has no TypeName. To access the name of such types (e.g. "untyped +// int"), use [types.Basic.Name]. 
+func TypeNameFor(t types.Type) *types.TypeName { + switch t := t.(type) { + case *types.Alias: + return t.Obj() + case *types.Named: + return t.Obj() + case *types.TypeParam: + return t.Obj() + case *types.Basic: + // See issues #71886 and #66890 for some history. + if tname, ok := types.Universe.Lookup(t.Name()).(*types.TypeName); ok { + return tname + } + } + return nil +} + // A NamedOrAlias is a [types.Type] that is named (as // defined by the spec) and capable of bearing type parameters: it // abstracts aliases ([types.Alias]) and defined types diff --git a/refactor/rename/check.go b/refactor/rename/check.go index 58cbff9b594..f41213a7a73 100644 --- a/refactor/rename/check.go +++ b/refactor/rename/check.go @@ -772,7 +772,7 @@ func (r *renamer) checkMethod(from *types.Func) { var iface string I := recv(imeth).Type() - if named, ok := I.(hasTypeName); ok { + if named, ok := I.(hasTypeName); ok { // *Named or *Alias pos = named.Obj().Pos() iface = "interface " + named.Obj().Name() } else { diff --git a/refactor/rename/spec.go b/refactor/rename/spec.go index 0a6d7d4346c..c1854d4a5ad 100644 --- a/refactor/rename/spec.go +++ b/refactor/rename/spec.go @@ -463,7 +463,8 @@ func findObjects(info *loader.PackageInfo, spec *spec) ([]types.Object, error) { } if spec.searchFor == "" { - // If it is an embedded field, return the type of the field. + // If it is an embedded field (*Named or *Alias), + // return the type of the field. if v, ok := obj.(*types.Var); ok && v.Anonymous() { if t, ok := typesinternal.Unpointer(v.Type()).(hasTypeName); ok { return []types.Object{t.Obj()}, nil From b37bd0b8c1f563df447a22786c6cb71e558a5f2a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 18:20:49 -0400 Subject: [PATCH 086/196] internal/typesinternal: add go1.23 methods to NamedOrAlias ...and eliminate the compatibility shims. 
Change-Id: Iba04cdb0270e20c900b15aab94e5ed78f7c7171e Reviewed-on: https://go-review.googlesource.com/c/tools/+/672976 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Commit-Queue: Alan Donovan Auto-Submit: Alan Donovan --- .../internal/golang/completion/completion.go | 4 +-- gopls/internal/golang/completion/literal.go | 21 +++++++----- .../golang/stubmethods/stubcalledfunc.go | 2 +- .../golang/stubmethods/stubmethods.go | 2 +- internal/facts/imports.go | 4 +-- internal/typesinternal/types.go | 33 +++++-------------- 6 files changed, 26 insertions(+), 40 deletions(-) diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index f61fdc6f7ba..d6b49ca9d04 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -1818,7 +1818,7 @@ func (c *completer) injectType(ctx context.Context, t types.Type) { // considered via a lexical search, so we need to directly inject // them. Also allow generic types since lexical search does not // infer instantiated versions of them. - if pnt, ok := t.(typesinternal.NamedOrAlias); !ok || typesinternal.TypeParams(pnt).Len() > 0 { + if pnt, ok := t.(typesinternal.NamedOrAlias); !ok || pnt.TypeParams().Len() > 0 { // If our expected type is "[]int", this will add a literal // candidate of "[]int{}". 
c.literal(ctx, t, nil) @@ -2879,7 +2879,7 @@ func (c *completer) expectedCallParamType(inf candidateInference, node *ast.Call func expectedConstraint(t types.Type, idx int) types.Type { var tp *types.TypeParamList if pnt, ok := t.(typesinternal.NamedOrAlias); ok { - tp = typesinternal.TypeParams(pnt) + tp = pnt.TypeParams() } else if sig, _ := t.Underlying().(*types.Signature); sig != nil { tp = sig.TypeParams() } diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 20cce04b69f..5dc364724c6 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -517,21 +517,24 @@ func (c *completer) typeNameSnippet(literalType types.Type, qual types.Qualifier var ( snip snippet.Builder typeName string - pnt, _ = literalType.(typesinternal.NamedOrAlias) // = *Named | *Alias + tparams *types.TypeParamList ) - tparams := typesinternal.TypeParams(pnt) - if tparams.Len() > 0 && !c.fullyInstantiated(pnt) { - // tparams.Len() > 0 implies pnt != nil. - // Inv: pnt is not "error" or "unsafe.Pointer", so pnt.Obj() != nil and has a Pkg(). + t, ok := literalType.(typesinternal.NamedOrAlias) // = *Named | *Alias + if ok { + tparams = t.TypeParams() + } + if tparams.Len() > 0 && !c.fullyInstantiated(t) { + // tparams.Len() > 0 implies t != nil. + // Inv: t is not "error" or "unsafe.Pointer", so t.Obj() != nil and has a Pkg(). // We are not "fully instantiated" meaning we have type params that must be specified. - if pkg := qual(pnt.Obj().Pkg()); pkg != "" { + if pkg := qual(t.Obj().Pkg()); pkg != "" { typeName = pkg + "." } // We do this to get "someType" instead of "someType[T]". - typeName += pnt.Obj().Name() + typeName += t.Obj().Name() snip.WriteText(typeName + "[") if c.opts.placeholders { @@ -560,8 +563,8 @@ func (c *completer) typeNameSnippet(literalType types.Type, qual types.Qualifier // fullyInstantiated reports whether all of t's type params have // specified type args. 
func (c *completer) fullyInstantiated(t typesinternal.NamedOrAlias) bool { - targs := typesinternal.TypeArgs(t) - tparams := typesinternal.TypeParams(t) + targs := t.TypeArgs() + tparams := t.TypeParams() if tparams.Len() != targs.Len() { return false diff --git a/gopls/internal/golang/stubmethods/stubcalledfunc.go b/gopls/internal/golang/stubmethods/stubcalledfunc.go index b4b59340d83..a40bf23924d 100644 --- a/gopls/internal/golang/stubmethods/stubcalledfunc.go +++ b/gopls/internal/golang/stubmethods/stubcalledfunc.go @@ -121,7 +121,7 @@ func (si *CallStubInfo) Emit(out *bytes.Buffer, qual types.Qualifier) error { recvName, star, recv.Name(), - typesutil.FormatTypeParams(typesinternal.TypeParams(si.Receiver)), + typesutil.FormatTypeParams(si.Receiver.TypeParams()), si.MethodName) // Emit parameters, avoiding name conflicts. diff --git a/gopls/internal/golang/stubmethods/stubmethods.go b/gopls/internal/golang/stubmethods/stubmethods.go index 43842264d70..317a55325e5 100644 --- a/gopls/internal/golang/stubmethods/stubmethods.go +++ b/gopls/internal/golang/stubmethods/stubmethods.go @@ -200,7 +200,7 @@ func (si *IfaceStubInfo) Emit(out *bytes.Buffer, qual types.Qualifier) error { mrn, star, si.Concrete.Obj().Name(), - typesutil.FormatTypeParams(typesinternal.TypeParams(si.Concrete)), + typesutil.FormatTypeParams(si.Concrete.TypeParams()), missing[index].fn.Name(), strings.TrimPrefix(types.TypeString(missing[index].fn.Type(), qual), "func")) } diff --git a/internal/facts/imports.go b/internal/facts/imports.go index ed5ec5fa131..cc9383e8004 100644 --- a/internal/facts/imports.go +++ b/internal/facts/imports.go @@ -52,7 +52,7 @@ func importMap(imports []*types.Package) map[string]*types.Package { // nop case typesinternal.NamedOrAlias: // *types.{Named,Alias} // Add the type arguments if this is an instance. 
- if targs := typesinternal.TypeArgs(T); targs.Len() > 0 { + if targs := T.TypeArgs(); targs.Len() > 0 { for i := 0; i < targs.Len(); i++ { addType(targs.At(i)) } @@ -69,7 +69,7 @@ func importMap(imports []*types.Package) map[string]*types.Package { // common aspects addObj(T.Obj()) - if tparams := typesinternal.TypeParams(T); tparams.Len() > 0 { + if tparams := T.TypeParams(); tparams.Len() > 0 { for i := 0; i < tparams.Len(); i++ { addType(tparams.At(i)) } diff --git a/internal/typesinternal/types.go b/internal/typesinternal/types.go index 1f292edb6d3..a5cd7e8dbfc 100644 --- a/internal/typesinternal/types.go +++ b/internal/typesinternal/types.go @@ -105,7 +105,7 @@ func TypeNameFor(t types.Type) *types.TypeName { // Every type declared by an explicit "type" declaration is a // NamedOrAlias. (Built-in type symbols may additionally // have type [types.Basic], which is not a NamedOrAlias, -// though the spec regards them as "named".) +// though the spec regards them as "named"; see [TypeNameFor].) // // NamedOrAlias cannot expose the Origin method, because // [types.Alias.Origin] and [types.Named.Origin] have different @@ -113,32 +113,15 @@ func TypeNameFor(t types.Type) *types.TypeName { type NamedOrAlias interface { types.Type Obj() *types.TypeName - // TODO(hxjiang): add method TypeArgs() *types.TypeList after stop supporting go1.22. + TypeArgs() *types.TypeList + TypeParams() *types.TypeParamList + SetTypeParams(tparams []*types.TypeParam) } -// TypeParams is a light shim around t.TypeParams(). -// (go/types.Alias).TypeParams requires >= 1.23. -func TypeParams(t NamedOrAlias) *types.TypeParamList { - switch t := t.(type) { - case *types.Alias: - return aliases.TypeParams(t) - case *types.Named: - return t.TypeParams() - } - return nil -} - -// TypeArgs is a light shim around t.TypeArgs(). -// (go/types.Alias).TypeArgs requires >= 1.23. 
-func TypeArgs(t NamedOrAlias) *types.TypeList { - switch t := t.(type) { - case *types.Alias: - return aliases.TypeArgs(t) - case *types.Named: - return t.TypeArgs() - } - return nil -} +var ( + _ NamedOrAlias = (*types.Alias)(nil) + _ NamedOrAlias = (*types.Named)(nil) +) // Origin returns the generic type of the Named or Alias type t if it // is instantiated, otherwise it returns t. From bc8c84cbab37ef35fe49e05ac839276471471633 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 15 May 2025 13:25:24 -0400 Subject: [PATCH 087/196] gopls/internal/golang: AddTest: fix types.Package.Path nil panic Also, simplify isContextType by using a library function. Fixes golang/go#73687 Change-Id: Ic2483f26d15560bbfd69ce53f0c348343b6c080c Reviewed-on: https://go-review.googlesource.com/c/tools/+/673198 Reviewed-by: Hongxiang Jiang LUCI-TryBot-Result: Go LUCI --- gopls/internal/golang/addtest.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index 89d0be3d1fd..66ed9716c9a 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -28,6 +28,7 @@ import ( "golang.org/x/tools/gopls/internal/protocol" goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/moremaps" + "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/typesinternal" ) @@ -480,12 +481,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. 
}, } - var isContextType = func(t types.Type) bool { - named, ok := t.(*types.Named) - if !ok { - return false - } - return named.Obj().Pkg().Path() == "context" && named.Obj().Name() == "Context" + isContextType := func(t types.Type) bool { + return analysisinternal.IsTypeNamed(t, "context", "Context") } for i := range sig.Params().Len() { From 5c7400c9e565bd7da802f00d301c6fb716c2a946 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 16:27:16 -0400 Subject: [PATCH 088/196] gopls/internal/cache/parsego: use PreorderStack A minor cleanup, and opportunity to converge on go1.25 std API. Updates golang/go#73319 Change-Id: Id0a6bf01b7db0ef31b0920778065666fb6b92464 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672935 Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/internal/cache/parsego/parse.go | 37 +++++++-------------------- internal/astutil/util.go | 2 ++ 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/gopls/internal/cache/parsego/parse.go b/gopls/internal/cache/parsego/parse.go index df4d9c8e44d..3346edd2b7a 100644 --- a/gopls/internal/cache/parsego/parse.go +++ b/gopls/internal/cache/parsego/parse.go @@ -28,6 +28,7 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/safetoken" + internalastutil "golang.org/x/tools/internal/astutil" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" ) @@ -148,7 +149,12 @@ func Parse(ctx context.Context, fset *token.FileSet, uri protocol.DocumentURI, s // positions have been mangled, and type checker errors may not make sense. 
func fixAST(n ast.Node, tok *token.File, src []byte) (fixes []FixType) { var err error - walkASTWithParent(n, func(n, parent ast.Node) bool { + internalastutil.PreorderStack(n, nil, func(n ast.Node, stack []ast.Node) bool { + var parent ast.Node + if len(stack) > 0 { + parent = stack[len(stack)-1] + } + switch n := n.(type) { case *ast.BadStmt: if fixDeferOrGoStmt(n, parent, tok, src) { @@ -207,32 +213,6 @@ func fixAST(n ast.Node, tok *token.File, src []byte) (fixes []FixType) { return fixes } -// walkASTWithParent walks the AST rooted at n. The semantics are -// similar to ast.Inspect except it does not call f(nil). -// TODO(adonovan): replace with PreorderStack. -func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) { - var ancestors []ast.Node - ast.Inspect(n, func(n ast.Node) (recurse bool) { - defer func() { - if recurse { - ancestors = append(ancestors, n) - } - }() - - if n == nil { - ancestors = ancestors[:len(ancestors)-1] - return false - } - - var parent ast.Node - if len(ancestors) > 0 { - parent = ancestors[len(ancestors)-1] - } - - return f(n, parent) - }) -} - // TODO(rfindley): revert this intrumentation once we're certain the crash in // #59097 is fixed. type FixType int @@ -253,13 +233,14 @@ const ( // // fixSrc returns a non-nil result if and only if a fix was applied. 
func fixSrc(f *ast.File, tf *token.File, src []byte) (newSrc []byte, fix FixType) { - walkASTWithParent(f, func(n, parent ast.Node) bool { + internalastutil.PreorderStack(f, nil, func(n ast.Node, stack []ast.Node) bool { if newSrc != nil { return false } switch n := n.(type) { case *ast.BlockStmt: + parent := stack[len(stack)-1] newSrc = fixMissingCurlies(f, n, parent, tf, src) if newSrc != nil { fix = FixedCurlies diff --git a/internal/astutil/util.go b/internal/astutil/util.go index 1862668a7c6..f06dbda3697 100644 --- a/internal/astutil/util.go +++ b/internal/astutil/util.go @@ -71,6 +71,8 @@ func PosInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) { // In practice, the second call is nearly always used only to pop the // stack, and it is surprisingly tricky to do this correctly; see // https://go.dev/issue/73319. +// +// TODO(adonovan): replace with [ast.PreorderStack] when go1.25 is assured. func PreorderStack(root ast.Node, stack []ast.Node, f func(n ast.Node, stack []ast.Node) bool) { before := len(stack) ast.Inspect(root, func(n ast.Node) bool { From 2263a61d5ce1f05b275114b1c2d067ab1db3bf41 Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Fri, 16 May 2025 23:40:12 +0800 Subject: [PATCH 089/196] gopls/internal/test/marker: add a folding ranges test case This CL adds a test case for folding_range, and demonstrates gopls current reports folding range correctly. So the issue may not caused by gopls. 
Updates: golang/go#73735 Change-Id: I60e7a649ee10a4ffecb5468981e80c96f0ada455 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673395 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Reviewed-by: Alan Donovan Commit-Queue: Alan Donovan Auto-Submit: Alan Donovan --- gopls/internal/test/marker/testdata/foldingrange/a.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gopls/internal/test/marker/testdata/foldingrange/a.txt b/gopls/internal/test/marker/testdata/foldingrange/a.txt index 864442e1b0c..f64a6e0014a 100644 --- a/gopls/internal/test/marker/testdata/foldingrange/a.txt +++ b/gopls/internal/test/marker/testdata/foldingrange/a.txt @@ -131,6 +131,11 @@ func _( c int, ) { } + +func _() { // comment + +} + -- @raw -- package folding //@foldingrange(raw) @@ -262,3 +267,8 @@ func _(<52 kind=""> c int, ) {<53 kind=""> } + +func _() {<54 kind=""> // comment + +} + From c460ea9266b569116af467577816bc9d9f4e2676 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 16 May 2025 11:16:59 -0400 Subject: [PATCH 090/196] gopls/internal/analysis/modernize: disable slicesdelete pass Unfortunately, it is not safe to replace append(s[:i], s[j:]...) with slices.Delete(s, i, j) because the latter clears out the vacated array slots, which may have unexpected effects on program behavior in esoteric cases. So we disable the pass for now. Perhaps we can later enable it after either making it more sophisticated or revising our safety goals. 
Fixes golang/go#73686 Change-Id: Ia8be98c8f26c1699b7d75f25e2efdb5a71696031 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673516 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley Commit-Queue: Alan Donovan --- gopls/doc/analyzers.md | 3 --- gopls/internal/analysis/modernize/doc.go | 3 --- .../internal/analysis/modernize/modernize_test.go | 2 ++ gopls/internal/analysis/modernize/slicesdelete.go | 15 +++++++++++++++ gopls/internal/doc/api.json | 4 ++-- 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index e18a7c7efda..915afe346dc 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3741,9 +3741,6 @@ Categories of modernize diagnostic: benchmark with "for b.Loop()", and remove any preceding calls to b.StopTimer, b.StartTimer, and b.ResetTimer. - - slicesdelete: replace append(s[:i], s[i+1]...) by - slices.Delete(s, i, i+1), added in go1.21. - - rangeint: replace a 3-clause "for i := 0; i < n; i++" loop by "for i := range n", added in go1.22. diff --git a/gopls/internal/analysis/modernize/doc.go b/gopls/internal/analysis/modernize/doc.go index 2c4b893f6d2..e7cf5c9c8fd 100644 --- a/gopls/internal/analysis/modernize/doc.go +++ b/gopls/internal/analysis/modernize/doc.go @@ -82,9 +82,6 @@ // benchmark with "for b.Loop()", and remove any preceding calls // to b.StopTimer, b.StartTimer, and b.ResetTimer. // -// - slicesdelete: replace append(s[:i], s[i+1]...) by -// slices.Delete(s, i, i+1), added in go1.21. -// // - rangeint: replace a 3-clause "for i := 0; i < n; i++" loop by // "for i := range n", added in go1.22. 
// diff --git a/gopls/internal/analysis/modernize/modernize_test.go b/gopls/internal/analysis/modernize/modernize_test.go index e823e983995..7ef77f16bce 100644 --- a/gopls/internal/analysis/modernize/modernize_test.go +++ b/gopls/internal/analysis/modernize/modernize_test.go @@ -12,6 +12,8 @@ import ( ) func Test(t *testing.T) { + modernize.EnableSlicesDelete = true + analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), modernize.Analyzer, "appendclipped", "bloop", diff --git a/gopls/internal/analysis/modernize/slicesdelete.go b/gopls/internal/analysis/modernize/slicesdelete.go index 493009c35be..2d396787ddf 100644 --- a/gopls/internal/analysis/modernize/slicesdelete.go +++ b/gopls/internal/analysis/modernize/slicesdelete.go @@ -16,11 +16,26 @@ import ( "golang.org/x/tools/internal/analysisinternal" ) +// slices.Delete is not equivalent to append(s[:i], [j:]...): +// it clears the vacated array slots; see #73686. +// Until we either fix it or revise our safety goals, +// we disable this analyzer for now. +// +// Its former documentation in doc.go was: +// +// - slicesdelete: replace append(s[:i], s[i+1]...) by +// slices.Delete(s, i, i+1), added in go1.21. +var EnableSlicesDelete = false + // The slicesdelete pass attempts to replace instances of append(s[:i], s[i+k:]...) // with slices.Delete(s, i, i+k) where k is some positive constant. // Other variations that will also have suggested replacements include: // append(s[:i-1], s[i:]...) and append(s[:i+k1], s[i+k2:]) where k2 > k1. func slicesdelete(pass *analysis.Pass) { + if !EnableSlicesDelete { + return + } + // Skip the analyzer in packages where its // fixes would create an import cycle. 
if within(pass, "slices", "runtime") { diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 37a996950be..969bc1a17ef 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1492,7 +1492,7 @@ }, { "Name": "\"modernize\"", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. 
It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - slicesdelete: replace append(s[:i], s[i+1]...) 
by\n slices.Delete(s, i, i+1), added in go1.21.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. 
In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "Default": "true", "Status": "" }, @@ -3212,7 +3212,7 @@ }, { "Name": "modernize", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - slicesdelete: replace append(s[:i], s[i+1]...) 
by\n slices.Delete(s, i, i+1), added in go1.21.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. 
In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/modernize", "Default": true }, From 5a46e4d95cb79a92437e0914a7318b20f77ae79f Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Sat, 17 May 2025 00:01:39 -0600 Subject: [PATCH 091/196] gopls/internal/analysis/modernize/slicesdelete: convert index type if needed As go compiler has a rule "the index x must be an untyped constant or its core type must be an integer" but slices.Delete accepts the int only, the compatible index type may break type checking after applying modernize fix. This CL checks whether the index type is int type, and converts it to int by 'int(expr)'. 
Fixes golang/go#73663 Change-Id: I13553e8149baa290f8a0793ed5615481a1418919 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671735 Reviewed-by: Alan Donovan Reviewed-by: Michael Knyszek Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- .../internal/analysis/modernize/modernize.go | 1 + .../analysis/modernize/slicesdelete.go | 33 ++++++++++++++++++- .../testdata/src/slicesdelete/slicesdelete.go | 12 +++++++ .../src/slicesdelete/slicesdelete.go.golden | 12 +++++++ 4 files changed, 57 insertions(+), 1 deletion(-) diff --git a/gopls/internal/analysis/modernize/modernize.go b/gopls/internal/analysis/modernize/modernize.go index d092c10c313..65fb81dd9de 100644 --- a/gopls/internal/analysis/modernize/modernize.go +++ b/gopls/internal/analysis/modernize/modernize.go @@ -177,6 +177,7 @@ var ( builtinAny = types.Universe.Lookup("any") builtinAppend = types.Universe.Lookup("append") builtinBool = types.Universe.Lookup("bool") + builtinInt = types.Universe.Lookup("int") builtinFalse = types.Universe.Lookup("false") builtinLen = types.Universe.Lookup("len") builtinMake = types.Universe.Lookup("make") diff --git a/gopls/internal/analysis/modernize/slicesdelete.go b/gopls/internal/analysis/modernize/slicesdelete.go index 2d396787ddf..ca862863c9e 100644 --- a/gopls/internal/analysis/modernize/slicesdelete.go +++ b/gopls/internal/analysis/modernize/slicesdelete.go @@ -45,7 +45,39 @@ func slicesdelete(pass *analysis.Pass) { inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) info := pass.TypesInfo report := func(file *ast.File, call *ast.CallExpr, slice1, slice2 *ast.SliceExpr) { + insert := func(pos token.Pos, text string) analysis.TextEdit { + return analysis.TextEdit{Pos: pos, End: pos, NewText: []byte(text)} + } + isIntExpr := func(e ast.Expr) bool { + return types.Identical(types.Default(info.TypeOf(e)), builtinInt.Type()) + } + isIntShadowed := func() bool { + scope := pass.TypesInfo.Scopes[file].Innermost(call.Lparen) + if _, obj := 
scope.LookupParent("int", call.Lparen); obj != builtinInt { + return true // int type is shadowed + } + return false + } + _, prefix, edits := analysisinternal.AddImport(info, file, "slices", "slices", "Delete", call.Pos()) + // append's indices may be any integer type; slices.Delete requires int. + // Insert int conversions as needed (and if possible). + if isIntShadowed() && (!isIntExpr(slice1.High) || !isIntExpr(slice2.Low)) { + return + } + if !isIntExpr(slice1.High) { + edits = append(edits, + insert(slice1.High.Pos(), "int("), + insert(slice1.High.End(), ")"), + ) + } + if !isIntExpr(slice2.Low) { + edits = append(edits, + insert(slice2.Low.Pos(), "int("), + insert(slice2.Low.End(), ")"), + ) + } + pass.Report(analysis.Diagnostic{ Pos: call.Pos(), End: call.End(), @@ -123,7 +155,6 @@ func slicesdelete(pass *analysis.Pass) { // Given two slice indices a and b, returns true if we can verify that a < b. // It recognizes certain forms such as i+k1 < i+k2 where k1 < k2. func increasingSliceIndices(info *types.Info, a, b ast.Expr) bool { - // Given an expression of the form i±k, returns (i, k) // where k is a signed constant. Otherwise it returns (e, 0). split := func(e ast.Expr) (ast.Expr, constant.Value) { diff --git a/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go b/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go index 0ee608d8f9f..4d3a8abb98b 100644 --- a/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go +++ b/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go @@ -42,3 +42,15 @@ func slicesdelete(test, other []byte, i int) { _ = append(test[:1+2], test[i-1:]...) // cannot verify a < b } + +func issue73663(test, other []byte, i int32) { + const k = 1 + _ = append(test[:i], test[i+1:]...) // want "Replace append with slices.Delete" + + _ = append(test[:i-1], test[i:]...) // want "Replace append with slices.Delete" + + _ = append(g.f[:i], g.f[i+k:]...) 
// want "Replace append with slices.Delete" + + type int string // int is shadowed, so no offered fix. + _ = append(test[:i], test[i+1:]...) +} diff --git a/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go.golden b/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go.golden index a15eb07dee9..e0e39ab189a 100644 --- a/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/slicesdelete/slicesdelete.go.golden @@ -44,3 +44,15 @@ func slicesdelete(test, other []byte, i int) { _ = append(test[:1+2], test[i-1:]...) // cannot verify a < b } + +func issue73663(test, other []byte, i int32) { + const k = 1 + _ = slices.Delete(test, int(i), int(i+1)) // want "Replace append with slices.Delete" + + _ = slices.Delete(test, int(i-1), int(i)) // want "Replace append with slices.Delete" + + _ = slices.Delete(g.f, int(i), int(i+k)) // want "Replace append with slices.Delete" + + type int string // int is shadowed, so no offered fix. + _ = append(test[:i], test[i+1:]...) +} \ No newline at end of file From c5e4271849bf65ff4a9bc29806ba9cc43c29d2f1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 15 May 2025 16:35:46 -0400 Subject: [PATCH 092/196] internal/mcp: dispatcher middleware Implement the dispatcher middleware framework, allowing users to wrap method dispatching with their own code. The Dispatcher functions have an untyped params arg, but expect its value to be a typed, unmarshaled XXXParams struct. That required splitting the elegant dispatch[P, R] function into two parts, one to do the unmarshaling and one to call the method. The nicest way I could find to do that was to map from method names to a pair of functions, one to unmarshal and one to call. 
We could also write it as two switches on method name, which would require less machinery but would duplicate the list of method names, already a tad fragile because they are literal strings. I also exported the initialize params and result, so user Dispatchers could access them. All Params and Result types will have to be exported for that reason. Change-Id: Ie2c95cc17b4ef3d3d095dccdadb294e5f1590796 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673515 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 9 +- internal/mcp/design/design.md | 14 +-- internal/mcp/generate.go | 8 +- internal/mcp/mcp_test.go | 61 +++++++++++ internal/mcp/protocol.go | 90 ++++++++-------- internal/mcp/server.go | 188 ++++++++++++++++++++++------------ internal/mcp/transport.go | 4 +- 7 files changed, 249 insertions(+), 125 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index dff7c1bdb43..66d520bd219 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -44,8 +44,7 @@ func NewClient(name, version string, opts *ClientOptions) *Client { } // ClientOptions configures the behavior of the client. -type ClientOptions struct { -} +type ClientOptions struct{} // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. 
@@ -82,14 +81,14 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e if err != nil { return nil, err } - params := &initializeParams{ + params := &InitializeParams{ ClientInfo: &implementation{Name: c.name, Version: c.version}, } if err := call(ctx, cs.conn, "initialize", params, &cs.initializeResult); err != nil { _ = cs.Close() return nil, err } - if err := cs.conn.Notify(ctx, "notifications/initialized", &initializedParams{}); err != nil { + if err := cs.conn.Notify(ctx, "notifications/initialized", &InitializedParams{}); err != nil { _ = cs.Close() return nil, err } @@ -105,7 +104,7 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e type ClientSession struct { conn *jsonrpc2.Connection client *Client - initializeResult *initializeResult + initializeResult *InitializeResult } // Close performs a graceful close of the connection, preventing new requests diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index e3c02db2284..ebf10e347b3 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -409,20 +409,20 @@ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesP We provide a mechanism to add MCP-level middleware on the server side, which runs after the request has been parsed but before any normal handling. ```go -// A Dispatcher dispatches an MCP message to the appropriate handler. +// A ServerMethodHandler dispatches an MCP message to the appropriate handler. // The params argument will be an XXXParams struct pointer, such as *GetPromptParams. // The response if err is non-nil should be an XXXResult struct pointer. 
-type Dispatcher func(ctx context.Context, s *ServerSession, method string, params any) (result any, err error) +type ServerMethodHandler func(ctx context.Context, s *ServerSession, method string, params any) (result any, err error) -// AddDispatchers calls each function from right to left on the previous result, beginning +// AddMiddlewarecalls each function from right to left on the previous result, beginning // with the server's current dispatcher, and installs the result as the new dispatcher. -func (*Server) AddDispatchers(middleware ...func(Dispatcher) Dispatcher)) +func (*Server) AddMiddleware(middleware ...func(ServerMethodHandler) ServerMethodHandler) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.Dispatcher) mcp.Dispatcher { +func withLogging(h mcp.ServerMethodHandler) mcp.ServerMethodHandler{ return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() @@ -430,9 +430,11 @@ func withLogging(h mcp.Dispatcher) mcp.Dispatcher { } } -server.AddDispatchers(withLogging) +server.AddMiddleware(withLogging) ``` +We will provide the same functionality on the client side as well. + **Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. These are rarely used. The most common is `OnError`, which occurs fewer than ten times in open-source code. 
### Errors diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index dc136714988..2c549cde126 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -66,7 +66,7 @@ var declarations = config{ "CallToolResult": {}, "CancelledNotification": { Name: "-", - Fields: config{"Params": {Name: "cancelledParams"}}, + Fields: config{"Params": {Name: "CancelledParams"}}, }, "ClientCapabilities": {}, "GetPromptRequest": { @@ -77,12 +77,12 @@ var declarations = config{ "Implementation": {Name: "implementation"}, "InitializeRequest": { Name: "-", - Fields: config{"Params": {Name: "initializeParams"}}, + Fields: config{"Params": {Name: "InitializeParams"}}, }, - "InitializeResult": {Name: "initializeResult"}, + "InitializeResult": {Name: "InitializeResult"}, "InitializedNotification": { Name: "-", - Fields: config{"Params": {Name: "initializedParams"}}, + Fields: config{"Params": {Name: "InitializedParams"}}, }, "ListPromptsRequest": { Name: "-", diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 4a0b0b4acfa..d8304b600ae 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -5,6 +5,7 @@ package mcp import ( + "bytes" "context" "errors" "fmt" @@ -397,4 +398,64 @@ func TestCancellation(t *testing.T) { } } +func TestAddMiddleware(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + s := NewServer("testServer", "v1.0.0", nil) + ss, err := s.Connect(ctx, st) + if err != nil { + t.Fatal(err) + } + // Wait for the server to exit after the client closes its connection. + var clientWG sync.WaitGroup + clientWG.Add(1) + go func() { + if err := ss.Wait(); err != nil { + t.Errorf("server failed: %v", err) + } + clientWG.Done() + }() + + var buf bytes.Buffer + buf.WriteByte('\n') + + // traceCalls creates a middleware function that prints the method before and after each call + // with the given prefix. 
+ traceCalls := func(prefix string) func(ServerMethodHandler) ServerMethodHandler { + return func(d ServerMethodHandler) ServerMethodHandler { + return func(ctx context.Context, ss *ServerSession, method string, params any) (any, error) { + fmt.Fprintf(&buf, "%s >%s\n", prefix, method) + defer fmt.Fprintf(&buf, "%s <%s\n", prefix, method) + return d(ctx, ss, method, params) + } + } + } + + // "1" is the outer middleware layer, called first; then "2" is called, and finally + // the default dispatcher. + s.AddMiddleware(traceCalls("1"), traceCalls("2")) + + c := NewClient("testClient", "v1.0.0", nil) + cs, err := c.Connect(ctx, ct) + if err != nil { + t.Fatal(err) + } + if _, err := cs.ListTools(ctx, nil); err != nil { + t.Fatal(err) + } + want := ` +1 >initialize +2 >initialize +2 tools/list +2 >tools/list +2 Date: Fri, 9 May 2025 11:42:29 -0400 Subject: [PATCH 093/196] gopls/internal/test/marker: add mcp tools action marker @mcptool(name, args, output=output) marker will trigger a tool call to MCP server. - name: name of the tool - args: json formatted marshal as map[string]any - output(named arg): the text output of tool call "output" is named argument because the Go MCP Server may introduce MCP tools that apply text edits to user workspace (may introducing "edits" named argument in the future). Currently, a JSON string is used for "output" as argument sizes are not expected to be very large. This format can be adjusted in the future if test readability becomes an issue. MCP server and client will only be activated through flag "-mcp". Remove session ID from hello_world mcp tool response since session ID is a global integer that may change based on the position of mcp tools test. Move "SessionEvent" type decl from mcp package to lsprpc. 
For golang/go#73580 Change-Id: I416557e1410392f6ff810d31e2ab9270eb90e3f0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/672056 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Auto-Submit: Hongxiang Jiang --- gopls/internal/cmd/serve.go | 4 +- gopls/internal/lsprpc/lsprpc.go | 27 +++- gopls/internal/mcp/mcp.go | 122 ++++++++---------- gopls/internal/test/integration/env.go | 25 ++++ gopls/internal/test/marker/doc.go | 7 + gopls/internal/test/marker/marker_test.go | 85 ++++++++++-- .../marker/testdata/mcptools/hello_world.txt | 15 +++ 7 files changed, 194 insertions(+), 91 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/mcptools/hello_world.txt diff --git a/gopls/internal/cmd/serve.go b/gopls/internal/cmd/serve.go index 7da129c8f2a..761895a73e2 100644 --- a/gopls/internal/cmd/serve.go +++ b/gopls/internal/cmd/serve.go @@ -104,7 +104,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { // (creation, exit) to the MCP server. The sender must ensure that an exit // event for a given LSP session ID is sent after its corresponding creation // event. - var eventChan chan mcp.SessionEvent + var eventChan chan lsprpc.SessionEvent // cache shared between MCP and LSP servers. 
var ca *cache.Cache @@ -116,7 +116,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { } } else { if s.MCPAddress != "" { - eventChan = make(chan mcp.SessionEvent) + eventChan = make(chan lsprpc.SessionEvent) } ca = cache.New(nil) ss = lsprpc.NewStreamServer(ca, isDaemon, eventChan, s.app.options) diff --git a/gopls/internal/lsprpc/lsprpc.go b/gopls/internal/lsprpc/lsprpc.go index b7fb40139f9..f432d64aa76 100644 --- a/gopls/internal/lsprpc/lsprpc.go +++ b/gopls/internal/lsprpc/lsprpc.go @@ -22,7 +22,6 @@ import ( "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/debug" "golang.org/x/tools/gopls/internal/label" - "golang.org/x/tools/gopls/internal/mcp" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/gopls/internal/server" @@ -31,6 +30,20 @@ import ( "golang.org/x/tools/internal/jsonrpc2" ) +// SessionEventType differentiates between new and exiting sessions. +type SessionEventType int + +const ( + SessionStart SessionEventType = iota + SessionEnd +) + +// SessionEvent holds information about the session event. +type SessionEvent struct { + Type SessionEventType + Session *cache.Session +} + // Unique identifiers for client/server. var serverIndex int64 @@ -49,13 +62,13 @@ type streamServer struct { // eventChan is an optional channel for LSP server session lifecycle events, // including session creation and termination. If nil, no events are sent. - eventChan chan mcp.SessionEvent + eventChan chan SessionEvent } // NewStreamServer creates a StreamServer using the shared cache. If // withTelemetry is true, each session is instrumented with telemetry that // records RPC statistics. 
-func NewStreamServer(cache *cache.Cache, daemon bool, eventChan chan mcp.SessionEvent, optionsFunc func(*settings.Options)) jsonrpc2.StreamServer { +func NewStreamServer(cache *cache.Cache, daemon bool, eventChan chan SessionEvent, optionsFunc func(*settings.Options)) jsonrpc2.StreamServer { return &streamServer{cache: cache, daemon: daemon, eventChan: eventChan, optionsOverrides: optionsFunc} } @@ -93,14 +106,14 @@ func (s *streamServer) ServeStream(ctx context.Context, conn jsonrpc2.Conn) erro jsonrpc2.MethodNotFound)))) if s.eventChan != nil { - s.eventChan <- mcp.SessionEvent{ + s.eventChan <- SessionEvent{ Session: session, - Type: mcp.SessionNew, + Type: SessionStart, } defer func() { - s.eventChan <- mcp.SessionEvent{ + s.eventChan <- SessionEvent{ Session: session, - Type: mcp.SessionExiting, + Type: SessionEnd, } }() } diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index 8d1b115ad34..1a4a595cd54 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -13,73 +13,76 @@ import ( "sync" "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/lsprpc" "golang.org/x/tools/gopls/internal/util/moremaps" "golang.org/x/tools/internal/mcp" ) -// EventType differentiates between new and exiting sessions. -type EventType int - -const ( - SessionNew EventType = iota - SessionExiting -) - -// SessionEvent holds information about the session event. -type SessionEvent struct { - Type EventType - Session *cache.Session -} +// Serve start an MCP server serving at the input address. +// The server receives LSP session events on the specified channel, which the +// caller is responsible for closing. The server runs until the context is +// canceled. 
+func Serve(ctx context.Context, address string, eventChan <-chan lsprpc.SessionEvent, cache *cache.Cache, isDaemon bool) error { + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + defer listener.Close() -// Serve start a MCP server serving at the input address. -func Serve(ctx context.Context, address string, eventChan chan SessionEvent, cache *cache.Cache, isDaemon bool) error { - m := manager{ - mcpHandlers: make(map[string]*mcp.SSEHandler), - eventChan: eventChan, - cache: cache, - isDaemon: isDaemon, + // TODO(hxjiang): expose the MCP server address to the LSP client. + if isDaemon { + log.Printf("Gopls MCP daemon: listening on address %s...", listener.Addr()) } - return m.serve(ctx, address) -} + defer log.Printf("Gopls MCP server: exiting") -// manager manages the mapping between LSP sessions and MCP servers. -type manager struct { - mu sync.Mutex // lock for mcpHandlers. - mcpHandlers map[string]*mcp.SSEHandler // map from lsp session ids to MCP sse handlers. + svr := http.Server{ + Handler: HTTPHandler(eventChan, cache, isDaemon), + BaseContext: func(net.Listener) context.Context { + return ctx + }, + } - eventChan chan SessionEvent // channel for receiving session creation and termination event - isDaemon bool - cache *cache.Cache // TODO(hxjiang): use cache to perform static analysis + // Run the server until cancellation. + go func() { + <-ctx.Done() + svr.Close() + }() + return svr.Serve(listener) } -// serve serves MCP server at the input address. -func (m *manager) serve(ctx context.Context, address string) error { +// HTTPHandler returns an HTTP handler for handling requests from MCP client. +func HTTPHandler(eventChan <-chan lsprpc.SessionEvent, cache *cache.Cache, isDaemon bool) http.Handler { + var ( + mu sync.Mutex // lock for mcpHandlers. + mcpHandlers = make(map[string]*mcp.SSEHandler) // map from lsp session ids to MCP sse handlers. 
+ ) + // Spin up go routine listen to the session event channel until channel close. go func() { - for event := range m.eventChan { - m.mu.Lock() + for event := range eventChan { + mu.Lock() switch event.Type { - case SessionNew: - m.mcpHandlers[event.Session.ID()] = mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { - return newServer(m.cache, event.Session) + case lsprpc.SessionStart: + mcpHandlers[event.Session.ID()] = mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { + return newServer(cache, event.Session) }) - case SessionExiting: - delete(m.mcpHandlers, event.Session.ID()) + case lsprpc.SessionEnd: + delete(mcpHandlers, event.Session.ID()) } - m.mu.Unlock() + mu.Unlock() } }() // In daemon mode, gopls serves mcp server at ADDRESS/sessions/$SESSIONID. // Otherwise, gopls serves mcp server at ADDRESS. mux := http.NewServeMux() - if m.isDaemon { + if isDaemon { mux.HandleFunc("/sessions/{id}", func(w http.ResponseWriter, r *http.Request) { sessionID := r.PathValue("id") - m.mu.Lock() - handler := m.mcpHandlers[sessionID] - m.mu.Unlock() + mu.Lock() + handler := mcpHandlers[sessionID] + mu.Unlock() if handler == nil { http.Error(w, fmt.Sprintf("session %s not established", sessionID), http.StatusNotFound) @@ -91,10 +94,10 @@ func (m *manager) serve(ctx context.Context, address string) error { } else { // TODO(hxjiang): should gopls serve only at a specific path? mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - m.mu.Lock() + mu.Lock() // When not in daemon mode, gopls has at most one LSP session. 
- _, handler, ok := moremaps.Arbitrary(m.mcpHandlers) - m.mu.Unlock() + _, handler, ok := moremaps.Arbitrary(mcpHandlers) + mu.Unlock() if !ok { http.Error(w, "session not established", http.StatusNotFound) @@ -104,28 +107,7 @@ func (m *manager) serve(ctx context.Context, address string) error { handler.ServeHTTP(w, r) }) } - - listener, err := net.Listen("tcp", address) - if err != nil { - return err - } - defer listener.Close() - // TODO(hxjiang): expose the mcp server address to the lsp client. - if m.isDaemon { - log.Printf("Gopls MCP daemon: listening on address %s...", listener.Addr()) - } - defer log.Printf("Gopls MCP server: exiting") - - svr := http.Server{ - Handler: mux, - BaseContext: func(net.Listener) context.Context { return ctx }, - } - // Run the server until cancellation. - go func() { - <-ctx.Done() - svr.Close() - }() - return svr.Serve(listener) + return mux } func newServer(_ *cache.Cache, session *cache.Session) *mcp.Server { @@ -140,10 +122,10 @@ type HelloParams struct { Name string `json:"name" mcp:"the name to say hi to"` } -func helloHandler(session *cache.Session) func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { +func helloHandler(_ *cache.Session) func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { return func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { return []*mcp.Content{ - mcp.NewTextContent("Hi " + request.Name + ", this is lsp session " + session.ID()), + mcp.NewTextContent("Hi " + request.Name), }, nil } } diff --git a/gopls/internal/test/integration/env.go b/gopls/internal/test/integration/env.go index 822120e8324..a2f5449b42a 100644 --- a/gopls/internal/test/integration/env.go +++ b/gopls/internal/test/integration/env.go @@ -7,19 +7,23 @@ package integration import ( "context" "fmt" + "net/http/httptest" "strings" "sync" "sync/atomic" "testing" + 
"golang.org/x/tools/gopls/internal/lsprpc" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/test/integration/fake" "golang.org/x/tools/internal/jsonrpc2/servertest" + "golang.org/x/tools/internal/mcp" ) // Env holds the building blocks of an editor testing environment, providing // wrapper methods that hide the boilerplate of plumbing contexts and checking // errors. +// Call [Env.Shutdown] for cleaning up resources after the test. type Env struct { TB testing.TB Ctx context.Context @@ -33,6 +37,12 @@ type Env struct { Editor *fake.Editor Awaiter *Awaiter + + // MCPServer, MCPSession and EventChan is owned by the Env, and shut down. + // Only available if the test enables MCP Server. + MCPServer *httptest.Server + MCPSession *mcp.ClientSession + EventChan chan<- lsprpc.SessionEvent } // nextAwaiterRegistration is used to create unique IDs for various Awaiter @@ -325,6 +335,21 @@ func (e *Env) OnceMet(pre Expectation, mustMeets ...Expectation) { e.Await(OnceMet(pre, AllOf(mustMeets...))) } +// Shutdown releases the resources of an Env that is no longer needed. +func (e *Env) Shutdown() { + e.Sandbox.Close() // ignore error + e.Editor.Shutdown(context.Background()) // ignore error + if e.MCPSession != nil { + e.MCPSession.Close() // ignore error + } + if e.MCPServer != nil { + e.MCPServer.Close() + } + if e.EventChan != nil { + close(e.EventChan) + } +} + // Await waits for all expectations to simultaneously be met. It should only be // called from the main test goroutine. 
func (a *Awaiter) Await(ctx context.Context, expectation Expectation) error { diff --git a/gopls/internal/test/marker/doc.go b/gopls/internal/test/marker/doc.go index 604ee4c4033..131d799a758 100644 --- a/gopls/internal/test/marker/doc.go +++ b/gopls/internal/test/marker/doc.go @@ -307,6 +307,13 @@ Here is the list of supported action markers: location name kind + - mcptool(name string, args string, output=golden): Executes an MCP tool + call using the provided tool name and args (a JSON-encoded value). It then + asserts that the MCP server's response matches the content of the golden + file identified by output. Unlike golden references for file edits or file + results, which may contain multiple files (each with a path), the output + golden content here is a single entity, effectively having an empty path(""). + # Argument conversion Marker arguments are first parsed by the internal/expect package, which accepts diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 8cc7c56320d..f8aa59634d7 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -18,6 +18,7 @@ import ( "go/types" "io/fs" "log" + "net/http/httptest" "os" "path" "path/filepath" @@ -34,6 +35,7 @@ import ( "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/debug" "golang.org/x/tools/gopls/internal/lsprpc" + internalmcp "golang.org/x/tools/gopls/internal/mcp" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/test/compare" "golang.org/x/tools/gopls/internal/test/integration" @@ -45,6 +47,7 @@ import ( "golang.org/x/tools/internal/expect" "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/jsonrpc2/servertest" + "golang.org/x/tools/internal/mcp" "golang.org/x/tools/internal/testenv" "golang.org/x/tools/txtar" ) @@ -176,16 +179,13 @@ func Test(t *testing.T) { run := &markerTestRun{ test: test, - env: newEnv(t, cache, test.files, 
test.proxyFiles, test.writeGoSum, config), + env: newEnv(t, cache, test.files, test.proxyFiles, test.writeGoSum, config, test.mcp), settings: config.Settings, values: make(map[expect.Identifier]any), diags: make(map[protocol.Location][]protocol.Diagnostic), extraNotes: make(map[protocol.DocumentURI]map[string][]*expect.Note), } - - // TODO(rfindley): make it easier to clean up the integration test environment. - defer run.env.Editor.Shutdown(context.Background()) // ignore error - defer run.env.Sandbox.Close() // ignore error + defer run.env.Shutdown() // Open all files so that we operate consistently with LSP clients, and // (pragmatically) so that we have a Mapper available via the fake @@ -603,6 +603,7 @@ var actionMarkerFuncs = map[string]func(marker){ "token": actionMarkerFunc(tokenMarker), "typedef": actionMarkerFunc(typedefMarker), "workspacesymbol": actionMarkerFunc(workspaceSymbolMarker), + "mcptool": actionMarkerFunc(mcpToolMarker, "output"), } // markerTest holds all the test data extracted from a test txtar archive. @@ -637,6 +638,7 @@ type markerTest struct { filterBuiltins bool filterKeywords bool errorsOK bool + mcp bool } // flagSet returns the flagset used for parsing the special "flags" file in the @@ -654,6 +656,7 @@ func (t *markerTest) flagSet() *flag.FlagSet { flags.BoolVar(&t.filterBuiltins, "filter_builtins", true, "if set, filter builtins from completion results") flags.BoolVar(&t.filterKeywords, "filter_keywords", true, "if set, filter keywords from completion results") flags.BoolVar(&t.errorsOK, "errors_ok", false, "if set, Error level log messages are acceptable in this test") + flags.BoolVar(&t.mcp, "mcp", false, "if set, enable model context protocol client and server in this test") return flags } @@ -946,7 +949,7 @@ func formatTest(test *markerTest) ([]byte, error) { // // TODO(rfindley): simplify and refactor the construction of testing // environments across integration tests, marker tests, and benchmarks. 
-func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byte, writeGoSum []string, config fake.EditorConfig) *integration.Env { +func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byte, writeGoSum []string, config fake.EditorConfig, enableMCP bool) *integration.Env { sandbox, err := fake.NewSandbox(&fake.SandboxConfig{ RootDir: t.TempDir(), Files: files, @@ -968,13 +971,23 @@ func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byt ctx = debug.WithInstance(ctx) awaiter := integration.NewAwaiter(sandbox.Workdir) - ss := lsprpc.NewStreamServer(cache, false, nil, nil) + + var eventChan chan lsprpc.SessionEvent + var mcpServer *httptest.Server + if enableMCP { + eventChan = make(chan lsprpc.SessionEvent) + mcpServer = httptest.NewServer(internalmcp.HTTPHandler(eventChan, cache, false)) + } + + ss := lsprpc.NewStreamServer(cache, false, eventChan, nil) + server := servertest.NewPipeServer(ss, jsonrpc2.NewRawStream) editor, err := fake.NewEditor(sandbox, config).Connect(ctx, server, awaiter.Hooks()) if err != nil { sandbox.Close() // ignore error t.Fatal(err) } + if err := awaiter.Await(ctx, integration.OnceMet( integration.InitialWorkspaceLoad, integration.NoShownMessage(""), @@ -982,12 +995,25 @@ func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byt sandbox.Close() // ignore error t.Fatal(err) } + + var mcpSession *mcp.ClientSession + if enableMCP { + client := mcp.NewClient("test", "v1.0.0", nil) + mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL)) + if err != nil { + t.Fatalf("fail to connect to mcp server: %v", err) + } + } + return &integration.Env{ - TB: t, - Ctx: ctx, - Editor: editor, - Sandbox: sandbox, - Awaiter: awaiter, + TB: t, + Ctx: ctx, + Editor: editor, + Sandbox: sandbox, + Awaiter: awaiter, + MCPSession: mcpSession, + MCPServer: mcpServer, + EventChan: eventChan, } } @@ -2402,6 +2428,41 @@ func itemLocation(item 
protocol.CallHierarchyItem) protocol.Location { } } +func mcpToolMarker(mark marker, tool string, args string) { + var toolArgs map[string]any + if err := json.Unmarshal([]byte(args), &toolArgs); err != nil { + mark.errorf("fail to unmarshal arguments to map[string]any: %v", err) + return + } + + res, err := mark.run.env.MCPSession.CallTool(mark.ctx(), tool, toolArgs, nil) + if err != nil { + mark.errorf("failed to call mcp tool: %v", err) + return + } + + var buf bytes.Buffer + for i, c := range res.Content { + if c.Type != "text" { + mark.errorf("unsupported return content[%v] type: %s", i, c.Type) + } + buf.WriteString(c.Text) + } + if !bytes.HasSuffix(buf.Bytes(), []byte{'\n'}) { + buf.WriteString("\n") // all golden content is newline terminated + } + + got := buf.String() + + output := namedArg(mark, "output", expect.Identifier("")) + golden := mark.getGolden(output) + if want, ok := golden.Get(mark.T(), "", []byte(got)); !ok { + mark.errorf("%s: missing golden file @%s", mark.note.Name, golden.id) + } else if diff := cmp.Diff(got, string(want)); diff != "" { + mark.errorf("unexpected mcp tools call %s return; got:\n%s\n want:\n%s\ndiff:\n%s", tool, got, want, diff) + } +} + func incomingCallsMarker(mark marker, src protocol.Location, want ...protocol.Location) { getCalls := func(item protocol.CallHierarchyItem) ([]protocol.Location, error) { calls, err := mark.server().IncomingCalls(mark.ctx(), &protocol.CallHierarchyIncomingCallsParams{Item: item}) diff --git a/gopls/internal/test/marker/testdata/mcptools/hello_world.txt b/gopls/internal/test/marker/testdata/mcptools/hello_world.txt new file mode 100644 index 00000000000..5bae5afa416 --- /dev/null +++ b/gopls/internal/test/marker/testdata/mcptools/hello_world.txt @@ -0,0 +1,15 @@ +This test exercises mcp tool hello_world. 
+ +-- flags -- +-mcp + +-- go.mod -- +module golang.org/mcptests/mcptools + +-- mcp/tools/helloworld.go -- +package helloworld + +func A() {} //@mcptool("hello_world", `{"name": "jerry"}`, output=hello) + +-- @hello -- +Hi jerry From d71c72f562d1e732a586af6091ec9900379ace8a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 15 May 2025 13:59:59 -0400 Subject: [PATCH 094/196] gopls/internal/analysis/modernize: bloop: document deoptimization The bloop modernizer may prevent benchmarks from being trivialized (e.g. constant propagated out of existence). However, due to current compiler behavior it may also cause increased allocation by defeating inlining. This could be a surprising behavior change. Strictly speaking this violates our usual criteria for modernizers, but in this case we plan to just document the problem for now. Updates golang/go#73137 Fixes golang/go#73579 Change-Id: I29b87419ce5468f96dad2d046dca0dede65295c7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673200 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/doc/analyzers.md | 6 ++++++ gopls/internal/analysis/modernize/doc.go | 6 ++++++ gopls/internal/doc/api.json | 4 ++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index 915afe346dc..cea19c40ca3 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3741,6 +3741,12 @@ Categories of modernize diagnostic: benchmark with "for b.Loop()", and remove any preceding calls to b.StopTimer, b.StartTimer, and b.ResetTimer. + B.Loop intentionally defeats compiler optimizations such as + inlining so that the benchmark is not entirely optimized away. + Currently, however, it may cause benchmarks to become slower + in some cases due to increased allocation; see + https://go.dev/issue/73137. + - rangeint: replace a 3-clause "for i := 0; i < n; i++" loop by "for i := range n", added in go1.22. 
diff --git a/gopls/internal/analysis/modernize/doc.go b/gopls/internal/analysis/modernize/doc.go index e7cf5c9c8fd..e136807089f 100644 --- a/gopls/internal/analysis/modernize/doc.go +++ b/gopls/internal/analysis/modernize/doc.go @@ -82,6 +82,12 @@ // benchmark with "for b.Loop()", and remove any preceding calls // to b.StopTimer, b.StartTimer, and b.ResetTimer. // +// B.Loop intentionally defeats compiler optimizations such as +// inlining so that the benchmark is not entirely optimized away. +// Currently, however, it may cause benchmarks to become slower +// in some cases due to increased allocation; see +// https://go.dev/issue/73137. +// // - rangeint: replace a 3-clause "for i := 0; i < n; i++" loop by // "for i := range n", added in go1.22. // diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 969bc1a17ef..24fdc0e3835 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1492,7 +1492,7 @@ }, { "Name": "\"modernize\"", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "Default": "true", "Status": "" }, @@ -3212,7 +3212,7 @@ }, { "Name": "modernize", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/modernize", "Default": true }, From 7ae2e5ccc2ddcc3464bb7e0e2f71051667258fa8 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 16 May 2025 17:54:52 -0400 Subject: [PATCH 095/196] gopls/internal/golang: implement "inline local variable" code action This change adds a new refactor.inline.variable code action that replaces a reference to a local variable by that variable's initializer expression, if possible (not shadowed). The "inline all" case will be implemented in a follow-up. 
+ test, doc, relnote Updates golang/go#70085 Change-Id: I83c2f0ff9715239f8436b91f0b86bfd6a42746e0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673636 Auto-Submit: Alan Donovan Commit-Queue: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/doc/features/transformation.md | 50 ++++++++++++ gopls/doc/release/v0.19.0.md | 22 +++++- gopls/internal/golang/codeaction.go | 14 +++- gopls/internal/golang/extract.go | 17 ++-- gopls/internal/golang/fix.go | 4 +- gopls/internal/golang/inline.go | 79 +++++++++++++++++++ gopls/internal/settings/codeactionkind.go | 3 +- gopls/internal/settings/default.go | 1 + .../marker/testdata/codeaction/inline-var.txt | 38 +++++++++ 9 files changed, 216 insertions(+), 12 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/codeaction/inline-var.txt diff --git a/gopls/doc/features/transformation.md b/gopls/doc/features/transformation.md index 91b6c46b74d..f21904f902d 100644 --- a/gopls/doc/features/transformation.md +++ b/gopls/doc/features/transformation.md @@ -79,6 +79,7 @@ Gopls supports the following code actions: - [`refactor.extract.variable`](#extract) - [`refactor.extract.variable-all`](#extract) - [`refactor.inline.call`](#refactor.inline.call) +- [`refactor.inline.variable`](#refactor.inline.variable) - [`refactor.rewrite.addTags`](#refactor.rewrite.addTags) - [`refactor.rewrite.changeQuote`](#refactor.rewrite.changeQuote) - [`refactor.rewrite.fillStruct`](#refactor.rewrite.fillStruct) @@ -607,6 +608,55 @@ more detail. All of this is to say, it's a complex problem, and we aim for correctness first of all. We've already implemented a number of important "tidiness optimizations" and we expect more to follow. 
+ + +## `refactor.inline.variable`: Inline local variable + +For a `codeActions` request where the selection is (or is within) an +identifier that is a use of a local variable whose declaration has an +initializer expression, gopls will return a code action of kind +`refactor.inline.variable`, whose effect is to inline the variable: +that is, to replace the reference by the variable's initializer +expression. + +For example, if invoked on the identifier `s` in the call `println(s)`: +```go +func f(x int) { + s := fmt.Sprintf("+%d", x) + println(s) +} +``` +the code action transforms the code to: + +```go +func f(x int) { + s := fmt.Sprintf("+%d", x) + println(fmt.Sprintf("+%d", x)) +} +``` + +(In this instance, `s` becomes an unreferenced variable which you will +need to remove.) + +The code action always replaces the reference by the initializer +expression, even if there are later assignments to the variable (such +as `s = ""`). + +The code action reports an error if it is not possible to make the +transformation because one of the identifiers within the initializer +expression (e.g. `x` in the example above) is shadowed by an +intervening declaration, as in this example: + +```go +func f(x int) { + s := fmt.Sprintf("+%d", x) + { + x := 123 + println(s, x) // error: cannot replace s with fmt.Sprintf(...) since x is shadowed + } +} +``` + ## `refactor.rewrite`: Miscellaneous rewrites diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index b8f53a72304..8842098639e 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -135,4 +135,24 @@ type Info struct { LinkTarget string -> LinkTarget string `json:"link_target"` ... } -``` \ No newline at end of file +``` + +## Inline local variable + +The new `refactor.inline.variable` code action replaces a reference to +a local variable by that variable's initializer expression. 
For +example, when applied to `s` in `println(s)`: + +```go +func f(x int) { + s := fmt.Sprintf("+%d", x) + println(s) +} +``` +it transforms the code to: +```go +func f(x int) { + s := fmt.Sprintf("+%d", x) + println(fmt.Sprintf("+%d", x)) +} +``` diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 07b577de745..3d43d5694dc 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -253,6 +253,7 @@ var codeActionProducers = [...]codeActionProducer{ {kind: settings.RefactorExtractConstantAll, fn: refactorExtractVariableAll, needPkg: true}, {kind: settings.RefactorExtractVariableAll, fn: refactorExtractVariableAll, needPkg: true}, {kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true}, + {kind: settings.RefactorInlineVariable, fn: refactorInlineVariable, needPkg: true}, {kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote}, {kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true}, {kind: settings.RefactorRewriteFillSwitch, fn: refactorRewriteFillSwitch, needPkg: true}, @@ -506,7 +507,7 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error } // refactorExtractVariableAll produces "Extract N occurrences of EXPR" code action. -// See [extractAllOccursOfExpr] for command implementation. +// See [extractVariable] for implementation. func refactorExtractVariableAll(ctx context.Context, req *codeActionsRequest) error { info := req.pkg.TypesInfo() // Don't suggest if only one expr is found, @@ -957,6 +958,17 @@ func refactorInlineCall(ctx context.Context, req *codeActionsRequest) error { return nil } +// refactorInlineVariable produces the "Inline variable 'v'" code action. +// See [inlineVariableOne] for command implementation. +func refactorInlineVariable(ctx context.Context, req *codeActionsRequest) error { + // TODO(adonovan): offer "inline all" variant that eliminates the var (see #70085). 
+ if curUse, _, ok := canInlineVariable(req.pkg.TypesInfo(), req.pgf.Cursor, req.start, req.end); ok { + title := fmt.Sprintf("Inline variable %q", curUse.Node().(*ast.Ident).Name) + req.addApplyFixAction(title, fixInlineVariable, req.loc) + } + return nil +} + // goTest produces "Run tests and benchmarks" code actions. // See [server.commandHandler.runTests] for command implementation. func goTest(ctx context.Context, req *codeActionsRequest) error { diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 59916676fe9..70a84f6159d 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -31,23 +31,24 @@ import ( "golang.org/x/tools/internal/typesinternal" ) -// extractVariable implements the refactor.extract.{variable,constant} CodeAction command. -func extractVariable(pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { - return extractExprs(pkg, pgf, start, end, false) +// extractVariableOne implements the refactor.extract.{variable,constant} CodeAction command. +func extractVariableOne(pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { + return extractVariable(pkg, pgf, start, end, false) } // extractVariableAll implements the refactor.extract.{variable,constant}-all CodeAction command. func extractVariableAll(pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { - return extractExprs(pkg, pgf, start, end, true) + return extractVariable(pkg, pgf, start, end, true) } -// extractExprs replaces occurrence(s) of a specified expression within the same function -// with newVar. If 'all' is true, it replaces all occurrences of the same expression; -// otherwise, it only replaces the selected expression. +// extractVariable replaces one or all occurrences of a specified +// expression within the same function with newVar. 
If 'all' is true, +// it replaces all occurrences of the same expression; otherwise, it +// only replaces the selected expression. // // The new variable/constant is declared as close as possible to the first found expression // within the deepest common scope accessible to all candidate occurrences. -func extractExprs(pkg *cache.Package, pgf *parsego.File, start, end token.Pos, all bool) (*token.FileSet, *analysis.SuggestedFix, error) { +func extractVariable(pkg *cache.Package, pgf *parsego.File, start, end token.Pos, all bool) (*token.FileSet, *analysis.SuggestedFix, error) { var ( fset = pkg.FileSet() info = pkg.TypesInfo() diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index dbd83ef071f..0308c38c5cf 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -57,6 +57,7 @@ const ( fixExtractFunction = "extract_function" fixExtractMethod = "extract_method" fixInlineCall = "inline_call" + fixInlineVariable = "inline_variable" fixInvertIfCondition = "invert_if_condition" fixSplitLines = "split_lines" fixJoinLines = "join_lines" @@ -100,9 +101,10 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file // constructed directly by logic in server/code_action. 
fixExtractFunction: singleFile(extractFunction), fixExtractMethod: singleFile(extractMethod), - fixExtractVariable: singleFile(extractVariable), + fixExtractVariable: singleFile(extractVariableOne), fixExtractVariableAll: singleFile(extractVariableAll), fixInlineCall: inlineCall, + fixInlineVariable: singleFile(inlineVariableOne), fixInvertIfCondition: singleFile(invertIfCondition), fixSplitLines: singleFile(splitLines), fixJoinLines: singleFile(joinLines), diff --git a/gopls/internal/golang/inline.go b/gopls/internal/golang/inline.go index 8e5e906c566..15bc5bb52e0 100644 --- a/gopls/internal/golang/inline.go +++ b/gopls/internal/golang/inline.go @@ -15,6 +15,8 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/edge" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" @@ -23,6 +25,7 @@ import ( "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/refactor/inline" + "golang.org/x/tools/internal/typesinternal" ) // enclosingStaticCall returns the innermost function call enclosing @@ -134,3 +137,79 @@ func logger(ctx context.Context, name string, verbose bool) func(format string, return func(string, ...any) {} } } + +// canInlineVariable reports whether the selection is within an +// identifier that is a use of a variable that has an initializer +// expression. If so, it returns cursors for the identifier and the +// initializer expression. +func canInlineVariable(info *types.Info, curFile inspector.Cursor, start, end token.Pos) (_, _ inspector.Cursor, ok bool) { + if curUse, ok := curFile.FindByPos(start, end); ok { + if id, ok := curUse.Node().(*ast.Ident); ok { + if v, ok := info.Uses[id].(*types.Var); ok && + // Check that the variable is local. + // TODO(adonovan): simplify using go1.25 Var.Kind = Local. 
+ !typesinternal.IsPackageLevel(v) && !v.IsField() { + + if curIdent, ok := curFile.FindByPos(v.Pos(), v.Pos()); ok { + curParent := curIdent.Parent() + switch kind, index := curIdent.ParentEdge(); kind { + case edge.ValueSpec_Names: + // var v = expr + spec := curParent.Node().(*ast.ValueSpec) + if len(spec.Names) == len(spec.Values) { + return curUse, curParent.ChildAt(edge.ValueSpec_Values, index), true + } + case edge.AssignStmt_Lhs: + // v := expr + stmt := curParent.Node().(*ast.AssignStmt) + if len(stmt.Lhs) == len(stmt.Rhs) { + return curUse, curParent.ChildAt(edge.AssignStmt_Rhs, index), true + } + } + } + } + } + } + return +} + +// inlineVariableOne computes a fix to replace the selected variable by +// its initialization expression. +func inlineVariableOne(pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { + info := pkg.TypesInfo() + curUse, curRHS, ok := canInlineVariable(info, pgf.Cursor, start, end) + if !ok { + return nil, nil, fmt.Errorf("cannot inline variable here") + } + use := curUse.Node().(*ast.Ident) + + // Check that free symbols of rhs are unshadowed at curUse. + var ( + pos = use.Pos() + scope = info.Scopes[pgf.File].Innermost(pos) + ) + for curIdent := range curRHS.Preorder((*ast.Ident)(nil)) { + if ek, _ := curIdent.ParentEdge(); ek == edge.SelectorExpr_Sel { + continue // ignore f in x.f + } + id := curIdent.Node().(*ast.Ident) + obj1 := info.Uses[id] + _, obj2 := scope.LookupParent(id.Name, pos) + if obj1 != obj2 { + return nil, nil, fmt.Errorf("cannot inline variable: its initializer expression refers to %q, which is shadowed by the declaration at line %d", id.Name, safetoken.Position(pgf.Tok, obj2.Pos()).Line) + } + } + + // TODO(adonovan): also reject variables that are updated by assignments? 
+ + return pkg.FileSet(), &analysis.SuggestedFix{ + Message: fmt.Sprintf("Replace variable %q by its initializer expression", use.Name), + TextEdits: []analysis.TextEdit{ + { + Pos: use.Pos(), + End: use.End(), + NewText: []byte(FormatNode(pkg.FileSet(), curRHS.Node())), + }, + }, + }, nil +} diff --git a/gopls/internal/settings/codeactionkind.go b/gopls/internal/settings/codeactionkind.go index ebe9606adab..b617b94eea7 100644 --- a/gopls/internal/settings/codeactionkind.go +++ b/gopls/internal/settings/codeactionkind.go @@ -101,7 +101,8 @@ const ( RefactorRewriteRemoveTags protocol.CodeActionKind = "refactor.rewrite.removeTags" // refactor.inline - RefactorInlineCall protocol.CodeActionKind = "refactor.inline.call" + RefactorInlineCall protocol.CodeActionKind = "refactor.inline.call" + RefactorInlineVariable protocol.CodeActionKind = "refactor.inline.variable" // refactor.extract RefactorExtractConstant protocol.CodeActionKind = "refactor.extract.constant" diff --git a/gopls/internal/settings/default.go b/gopls/internal/settings/default.go index aa81640f3e8..70adc1ade02 100644 --- a/gopls/internal/settings/default.go +++ b/gopls/internal/settings/default.go @@ -62,6 +62,7 @@ func DefaultOptions(overrides ...func(*Options)) *Options { RefactorRewriteRemoveUnusedParam: true, RefactorRewriteSplitLines: true, RefactorInlineCall: true, + RefactorInlineVariable: true, RefactorExtractConstant: true, RefactorExtractConstantAll: true, RefactorExtractFunction: true, diff --git a/gopls/internal/test/marker/testdata/codeaction/inline-var.txt b/gopls/internal/test/marker/testdata/codeaction/inline-var.txt new file mode 100644 index 00000000000..f1770b18867 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/inline-var.txt @@ -0,0 +1,38 @@ +This is a test of the refactor.inline.variable code action. 
+ +-- go.mod -- +module example.com/a +go 1.18 + +-- a/a.go -- +package a + +import "fmt" + +func _(x int) { + s := fmt.Sprintf("+%d", x) + println(s) //@codeaction("s", "refactor.inline.variable", result=inlineS) +} + +-- @inlineS/a/a.go -- +package a + +import "fmt" + +func _(x int) { + s := fmt.Sprintf("+%d", x) + println(fmt.Sprintf("+%d", x)) //@codeaction("s", "refactor.inline.variable", result=inlineS) +} + +-- b/b.go -- +package b + +import "fmt" + +func _(x int) { + s2 := fmt.Sprintf("+%d", x) + { + x := "shadow" + println(s2, x) //@codeaction("s2", "refactor.inline.variable", err=re`refers to "x".*shadowed.*at line 8`) + } +} From 8edad1ef82bc6ccd164cb218e4e9fbb7b13b24d5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 13 May 2025 15:38:08 -0400 Subject: [PATCH 096/196] internal/mcp: add the FileResourceHandler method Provide secure reading of files on the local filesystem. There are subtleties in path manipulation on Windows that I don't fully understand, so I defer that work for now. 
Change-Id: Ic11f4617dde4ac0b31ccd5b0a772bb2f56db026f Reviewed-on: https://go-review.googlesource.com/c/tools/+/673196 Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI --- internal/mcp/design/design.md | 2 +- internal/mcp/mcp_test.go | 28 ++--- internal/mcp/resource.go | 158 +++++++++++++++++++++++++++ internal/mcp/resource_go124.go | 29 +++++ internal/mcp/resource_pre_go124.go | 25 +++++ internal/mcp/resource_test.go | 114 +++++++++++++++++++ internal/mcp/server.go | 72 +++++++----- internal/mcp/testdata/files/info.txt | 1 + 8 files changed, 386 insertions(+), 43 deletions(-) create mode 100644 internal/mcp/resource.go create mode 100644 internal/mcp/resource_go124.go create mode 100644 internal/mcp/resource_pre_go124.go create mode 100644 internal/mcp/resource_test.go create mode 100644 internal/mcp/testdata/files/info.txt diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index ebf10e347b3..fa61960aeeb 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -742,7 +742,7 @@ To read files from the local filesystem, we recommend using `FileResourceHandler ```go // FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. // It protects against path traversal attacks. -// It will not read any file that is not in the root set of the client session requesting the resource. +// It will not read any file that is not in the root set of the client requesting the resource. 
func (*Server) FileResourceHandler(dir string) ResourceHandler ``` diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index d8304b600ae..ebd988e4329 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -9,6 +9,8 @@ import ( "context" "errors" "fmt" + "path/filepath" + "runtime" "slices" "strings" "sync" @@ -83,7 +85,11 @@ func TestEndToEnd(t *testing.T) { }() c := NewClient("testClient", "v1.0.0", nil) - c.AddRoots(&Root{URI: "file:///root"}) + rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) + if err != nil { + t.Fatal(err) + } + c.AddRoots(&Root{URI: "file://" + rootAbs}) // Connect the client. cs, err := c.Connect(ctx, ct) @@ -189,10 +195,13 @@ func TestEndToEnd(t *testing.T) { }) t.Run("resources", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("TODO: fix for Windows") + } resource1 := &Resource{ Name: "public", MIMEType: "text/plain", - URI: "file:///file1.txt", + URI: "file:///info.txt", } resource2 := &Resource{ Name: "public", // names are not unique IDs @@ -200,16 +209,7 @@ func TestEndToEnd(t *testing.T) { URI: "file:///nonexistent.txt", } - readHandler := func(_ context.Context, _ *ServerSession, p *ReadResourceParams) (*ReadResourceResult, error) { - if p.URI == "file:///file1.txt" { - return &ReadResourceResult{ - Contents: &ResourceContents{ - Text: "file contents", - }, - }, nil - } - return nil, ResourceNotFoundError(p.URI) - } + readHandler := s.FileResourceHandler("testdata/files") s.AddResources( &ServerResource{resource1, readHandler}, &ServerResource{resource2, readHandler}) @@ -226,7 +226,7 @@ func TestEndToEnd(t *testing.T) { uri string mimeType string // "": not found; "text/plain": resource; "text/template": template }{ - {"file:///file1.txt", "text/plain"}, + {"file:///info.txt", "text/plain"}, {"file:///nonexistent.txt", ""}, // TODO(jba): add resource template cases when we implement them } { @@ -238,7 +238,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("%s: not found 
but expected it to be", tt.uri) } } else { - t.Fatalf("reading %s: %v", tt.uri, err) + t.Errorf("reading %s: %v", tt.uri, err) } } else { if got := rres.Contents.URI; got != tt.uri { diff --git a/internal/mcp/resource.go b/internal/mcp/resource.go new file mode 100644 index 00000000000..2d4bea75e3e --- /dev/null +++ b/internal/mcp/resource.go @@ -0,0 +1,158 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strings" + + jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" +) + +// A ServerResource associates a Resource with its handler. +type ServerResource struct { + Resource *Resource + Handler ResourceHandler +} + +// A ResourceHandler is a function that reads a resource. +// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. +type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) + +// ResourceNotFoundError returns an error indicating that a resource being read could +// not be found. +func ResourceNotFoundError(uri string) error { + return &jsonrpc2.WireError{ + Code: codeResourceNotFound, + Message: "Resource not found", + Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), + } +} + +// The error code to return when a resource isn't found. +// See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling +// However, the code they chose is in the wrong space +// (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). +// So we pick a different one, arbitrarily for now (until they fix it). +// The immediate problem is that jsonrpc2 defines -32002 as "server closing".
+const codeResourceNotFound = -31002 + +// readFileResource reads from the filesystem at a URI relative to dirFilepath, respecting +// the roots. +// dirFilepath and rootFilepaths are absolute filesystem paths. +func readFileResource(rawURI, dirFilepath string, rootFilepaths []string) ([]byte, error) { + uriFilepath, err := computeURIFilepath(rawURI, dirFilepath, rootFilepaths) + if err != nil { + return nil, err + } + + var data []byte + err = withFile(dirFilepath, uriFilepath, func(f *os.File) error { + var err error + data, err = io.ReadAll(f) + return err + }) + if os.IsNotExist(err) { + err = ResourceNotFoundError(rawURI) + } + return data, err +} + +// computeURIFilepath returns a path relative to dirFilepath. +// The dirFilepath and rootFilepaths are absolute file paths. +func computeURIFilepath(rawURI, dirFilepath string, rootFilepaths []string) (string, error) { + // We use "file path" to mean a filesystem path. + uri, err := url.Parse(rawURI) + if err != nil { + return "", err + } + if uri.Scheme != "file" { + return "", fmt.Errorf("URI is not a file: %s", uri) + } + if uri.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // The URI's path is interpreted relative to dirFilepath, and in the local filesystem. + // It must not try to escape its directory. + uriFilepathRel, err := filepath.Localize(strings.TrimPrefix(uri.Path, "/")) + if err != nil { + return "", fmt.Errorf("%q cannot be localized: %w", uriFilepathRel, err) + } + + // Check roots, if there are any. + if len(rootFilepaths) > 0 { + // To check against the roots, we need an absolute file path, not relative to the directory. + // uriFilepath is local, so the joined path is under dirFilepath. + uriFilepathAbs := filepath.Join(dirFilepath, uriFilepathRel) + rootOK := false + // Check that the requested file path is under some root. 
+ // Since both paths are absolute, that's equivalent to filepath.Rel constructing + // a local path. + for _, rootFilepathAbs := range rootFilepaths { + if rel, err := filepath.Rel(rootFilepathAbs, uriFilepathAbs); err == nil && filepath.IsLocal(rel) { + rootOK = true + break + } + } + if !rootOK { + return "", fmt.Errorf("URI path %q is not under any root", uriFilepathAbs) + } + } + return uriFilepathRel, nil +} + +// fileRoots transforms the Roots obtained from the client into absolute paths on +// the local filesystem. +// TODO(jba): expose this functionality to user ResourceHandlers, +// so they don't have to repeat it. +func fileRoots(rawRoots []*Root) ([]string, error) { + var fileRoots []string + for _, r := range rawRoots { + fr, err := fileRoot(r) + if err != nil { + return nil, err + } + fileRoots = append(fileRoots, fr) + } + return fileRoots, nil +} + +// fileRoot returns the absolute path for Root. +func fileRoot(root *Root) (_ string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("root %q: %w", root.URI, err) + } + }() + + // Convert to absolute file path. + rurl, err := url.Parse(root.URI) + if err != nil { + return "", err + } + if rurl.Scheme != "file" { + return "", errors.New("not a file URI") + } + if rurl.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // We don't want Localize here: we want an absolute path, which is not local. + fileRoot := filepath.Clean(filepath.FromSlash(rurl.Path)) + if !filepath.IsAbs(fileRoot) { + return "", errors.New("not an absolute path") + } + return fileRoot, nil +} diff --git a/internal/mcp/resource_go124.go b/internal/mcp/resource_go124.go new file mode 100644 index 00000000000..af4c4f3b74e --- /dev/null +++ b/internal/mcp/resource_go124.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package mcp + +import ( + "errors" + "os" +) + +// withFile calls f on the file at join(dir, rel), +// protecting against path traversal attacks. +func withFile(dir, rel string, f func(*os.File) error) (err error) { + r, err := os.OpenRoot(dir) + if err != nil { + return err + } + defer r.Close() + file, err := r.Open(rel) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/internal/mcp/resource_pre_go124.go b/internal/mcp/resource_pre_go124.go new file mode 100644 index 00000000000..77981c512d6 --- /dev/null +++ b/internal/mcp/resource_pre_go124.go @@ -0,0 +1,25 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.24 + +package mcp + +import ( + "errors" + "os" + "path/filepath" +) + +// withFile calls f on the file at join(dir, rel). +// It does not protect against path traversal attacks. +func withFile(dir, rel string, f func(*os.File) error) (err error) { + file, err := os.Open(filepath.Join(dir, rel)) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/internal/mcp/resource_test.go b/internal/mcp/resource_test.go new file mode 100644 index 00000000000..28e40eb416f --- /dev/null +++ b/internal/mcp/resource_test.go @@ -0,0 +1,114 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestFileRoot(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("TODO: fix for Windows") + } + + for _, tt := range []struct { + uri string + want string + wantErr string // error must contain this string + }{ + {uri: "file:///foo", want: "/foo"}, + {uri: "file:///foo/bar", want: "/foo/bar"}, + {uri: "file:///foo/../bar", want: "/bar"}, + {uri: "file:/foo", want: "/foo"}, + {uri: "http:///foo", wantErr: "not a file"}, + {uri: "file://foo", wantErr: "empty path"}, + {uri: ":", wantErr: "missing protocol scheme"}, + } { + got, err := fileRoot(&Root{URI: tt.uri}) + if err != nil { + if tt.wantErr == "" { + t.Errorf("%s: got %v, want success", tt.uri, err) + continue + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("%s: got %v, does not contain %q", tt.uri, err, tt.wantErr) + continue + } + } else if tt.wantErr != "" { + t.Errorf("%s: succeeded, but wanted error with %q", tt.uri, tt.wantErr) + } else if got != tt.want { + t.Errorf("%s: got %q, want %q", tt.uri, got, tt.want) + } + } +} + +func TestComputeURIFilepath(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("TODO: fix for Windows") + } + // TODO(jba): test with Windows \\host paths and C: paths + dirFilepath := filepath.FromSlash("/files") + rootFilepaths := []string{ + filepath.FromSlash("/files/public"), + filepath.FromSlash("/files/shared"), + } + for _, tt := range []struct { + uri string + want string + wantErr string // error must contain this + }{ + {"file:///public", "public", ""}, + {"file:///public/file", "public/file", ""}, + {"file:///shared/file", "shared/file", ""}, + {"http:///foo", "", "not a file"}, + {"file://foo", "", "empty"}, + {"file://foo/../bar", "", "localized"}, + {"file:///secret", "", "root"}, + {"file:///secret/file", "", "root"}, + {"file:///private/file", "", "root"}, + } { + t.Run(tt.uri, func(t *testing.T) { + tt.want = filepath.FromSlash(tt.want) // handle 
Windows + got, gotErr := computeURIFilepath(tt.uri, dirFilepath, rootFilepaths) + if gotErr != nil { + if tt.wantErr == "" { + t.Fatalf("got %v, wanted success", gotErr) + } + if !strings.Contains(gotErr.Error(), tt.wantErr) { + t.Fatalf("got error %v, does not contain %q", gotErr, tt.wantErr) + } + return + } + if tt.wantErr != "" { + t.Fatal("succeeded unexpectedly") + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestReadFileResource(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("TODO: fix for Windows") + } + abs, err := filepath.Abs("testdata") + if err != nil { + t.Fatal(err) + } + dirFilepath := filepath.Join(abs, "files") + got, err := readFileResource("file:///info.txt", dirFilepath, nil) + if err != nil { + t.Fatal(err) + } + want := "Contents\n" + if g := string(got); g != want { + t.Errorf("got %q, want %q", g, want) + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 24e8f1e6a13..79f4fe0a04f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -10,6 +10,7 @@ import ( "fmt" "iter" "net/url" + "path/filepath" "slices" "sync" @@ -103,34 +104,6 @@ func (s *Server) RemoveTools(names ...string) { } } -// ResourceNotFoundError returns an error indicating that a resource being read could -// not be found. -func ResourceNotFoundError(uri string) error { - return &jsonrpc2.WireError{ - Code: codeResourceNotFound, - Message: "Resource not found", - Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), - } -} - -// The error code to return when a resource isn't found. -// See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling -// However, the code they chose in in the wrong space -// (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). -// so we pick a different one, arbirarily for now (until they fix it). -// The immediate problem is that jsonprc2 defines -32002 as "server closing". 
-const codeResourceNotFound = -31002 - -// A ResourceHandler is a function that reads a resource. -// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) - -// A ServerResource associates a Resource with its handler. -type ServerResource struct { - Resource *Resource - Handler ResourceHandler -} - // AddResource adds the given resource to the server and associates it with // a [ResourceHandler], which will be called when the client calls [ClientSession.ReadResource]. // If a resource with the same URI already exists, this one replaces it. @@ -247,6 +220,49 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re return res, nil } +// FileResourceHandler returns a ResourceHandler that reads paths using dir as +// a base directory. +// It honors client roots and protects against path traversal attacks. +// +// The dir argument should be a filesystem path. It need not be absolute, but +// that is recommended to avoid a dependency on the current working directory (the +// check against client roots is done with an absolute path). If dir is not absolute +// and the current working directory is unavailable, FileResourceHandler panics. +// +// Lexical path traversal attacks, where the path has ".." elements that escape dir, +// are always caught. With Go 1.24 and above, it also protects against symlink-based attacks, +// where symlinks under dir lead out of the tree. +func (s *Server) FileResourceHandler(dir string) ResourceHandler { + // Convert dir to an absolute path.
+ dirFilepath, err := filepath.Abs(dir) + if err != nil { + panic(err) + } + return func(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (_ *ReadResourceResult, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("reading resource %s: %w", params.URI, err) + } + }() + + // TODO: use a memoizing API here. + rootRes, err := ss.ListRoots(ctx, nil) + if err != nil { + return nil, fmt.Errorf("listing roots: %w", err) + } + roots, err := fileRoots(rootRes.Roots) + if err != nil { + return nil, err + } + data, err := readFileResource(params.URI, dirFilepath, roots) + if err != nil { + return nil, err + } + // TODO(jba): figure out mime type. + return &ReadResourceResult{Contents: NewBlobResourceContents(params.URI, "text/plain", data)}, nil + } +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection. diff --git a/internal/mcp/testdata/files/info.txt b/internal/mcp/testdata/files/info.txt new file mode 100644 index 00000000000..dfe437bdebe --- /dev/null +++ b/internal/mcp/testdata/files/info.txt @@ -0,0 +1 @@ +Contents From db456f92b891b21ae6f9c211f209ec1e4e2f17e8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 16 May 2025 14:28:28 -0400 Subject: [PATCH 097/196] internal/mcp: client middleware Implement middleware for clients. Make the middleware types generic so we can factor out code from both peers. I was able to make the request handler generic, with one weird trick (see handleRequest). The first three declarations of shared.go, and a few others as well, are copied unchanged from middleware.go. 
Change-Id: Ib3239d8fcfda022ddb5b554c38cff4cdbc3704c1 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673635 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 74 +++++++++++------- internal/mcp/design/design.md | 31 +++++--- internal/mcp/mcp_test.go | 54 +++++++++----- internal/mcp/server.go | 133 +++++++++------------------------ internal/mcp/shared.go | 136 ++++++++++++++++++++++++++++++++++ 5 files changed, 274 insertions(+), 154 deletions(-) create mode 100644 internal/mcp/shared.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 66d520bd219..f19e779ef20 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -18,12 +18,13 @@ import ( // A Client is an MCP client, which may be connected to an MCP server // using the [Client.Connect] method. type Client struct { - name string - version string - opts ClientOptions - mu sync.Mutex - roots *featureSet[*Root] - sessions []*ClientSession + name string + version string + opts ClientOptions + mu sync.Mutex + roots *featureSet[*Root] + sessions []*ClientSession + methodHandler_ MethodHandler[ClientSession] } // NewClient creates a new Client. @@ -33,9 +34,10 @@ type Client struct { // If non-nil, the provided options configure the Client. func NewClient(name, version string, opts *ClientOptions) *Client { c := &Client{ - name: name, - version: version, - roots: newFeatureSet(func(r *Root) string { return r.URI }), + name: name, + version: version, + roots: newFeatureSet(func(r *Root) string { return r.URI }), + methodHandler_: defaultMethodHandler[ClientSession], } if opts != nil { c.opts = *opts @@ -140,7 +142,7 @@ func (c *Client) RemoveRoots(uris ...string) { c.roots.remove(uris...) 
} -func (c *Client) listRoots(_ context.Context, _ *ListRootsParams) (*ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsParams) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() return &ListRootsResult{ @@ -148,21 +150,43 @@ func (c *Client) listRoots(_ context.Context, _ *ListRootsParams) (*ListRootsRes }, nil } -func (c *ClientSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { - // TODO: when we switch to ClientSessions, use a copy of the server's dispatch function, or - // maybe just add another type parameter. - // - // No need to check that the connection is initialized, since we initialize - // it in Connect. - switch req.Method { - case "ping": - // The spec says that 'ping' expects an empty object result. - return struct{}{}, nil - case "roots/list": - // ListRootsParams happens to be unused. - return c.client.listRoots(ctx, nil) - } - return nil, jsonrpc2.ErrNotHandled +// AddMiddleware wraps the client's current method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one +// is executed first. +// +// For example, AddMiddleware(m1, m2, m3) augments the client method handler as +// m1(m2(m3(handler))). +func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.methodHandler_, middleware) +} + +// clientMethodInfos maps from the RPC method name to its methodInfo.
+var clientMethodInfos = map[string]methodInfo[ClientSession]{ + "ping": newMethodInfo(sessionMethod((*ClientSession).ping)), + "roots/list": newMethodInfo(clientMethod((*Client).listRoots)), + // TODO: notifications +} + +var _ session[ClientSession] = (*ClientSession)(nil) + +func (cs *ClientSession) methodInfos() map[string]methodInfo[ClientSession] { + return clientMethodInfos +} + +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { + return handleRequest(ctx, req, cs) +} + +func (cs *ClientSession) methodHandler() MethodHandler[ClientSession] { + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.methodHandler_ +} + +func (c *ClientSession) ping(ct context.Context, params *PingParams) (struct{}, error) { + return struct{}{}, nil } // Ping makes an MCP "ping" request to the server. diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index fa61960aeeb..2cf0f5a2d0c 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -406,23 +406,34 @@ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesP ### Middleware -We provide a mechanism to add MCP-level middleware on the server side, which runs after the request has been parsed but before any normal handling. +We provide a mechanism to add MCP-level middleware on both the client and server side, which runs after the request has been parsed but before any normal handling. ```go -// A ServerMethodHandler dispatches an MCP message to the appropriate handler. -// The params argument will be an XXXParams struct pointer, such as *GetPromptParams. -// The response if err is non-nil should be an XXXResult struct pointer. -type ServerMethodHandler func(ctx context.Context, s *ServerSession, method string, params any) (result any, err error) +// A MethodHandler handles MCP messages. +// The params argument is an XXXParams struct pointer, such as *GetPromptParams.
+// For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or +// nil with a non-nil error. +// For notifications, a MethodHandler must return nil, nil. +type MethodHandler[S ClientSession | ServerSession] func( + ctx context.Context, _ *S, method string, params any) (result any, err error) -// AddMiddlewarecalls each function from right to left on the previous result, beginning -// with the server's current dispatcher, and installs the result as the new dispatcher. -func (*Server) AddMiddleware(middleware ...func(ServerMethodHandler) ServerMethodHandler) +// Middleware is a function from MethodHandlers to MethodHandlers. +type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHandler[S] + +// AddMiddleware wraps the client/server's current method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one +// is executed first. +// +// For example, AddMiddleware(m1, m2, m3) augments the server method handler as +// m1(m2(m3(handler))). +func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) +func (s *Server) AddMiddleware(middleware ...Middleware[ServerSession]) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.ServerMethodHandler) mcp.ServerMethodHandler{ +func withLogging(h mcp.MethodHandler[ServerSession]) mcp.MethodHandler[ServerSession]{ return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() @@ -433,8 +444,6 @@ func withLogging(h mcp.ServerMethodHandler) mcp.ServerMethodHandler{ server.AddMiddleware(withLogging) ``` -We will provide the same functionality on the client side as well. - **Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. 
Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. These are rarely used. The most common is `OnError`, which occurs fewer than ten times in open-source code. ### Errors diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index ebd988e4329..ce66e9806eb 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + "io" "path/filepath" "runtime" "slices" @@ -27,8 +28,8 @@ type hiParams struct { Name string } -func sayHi(ctx context.Context, cc *ServerSession, v hiParams) ([]*Content, error) { - if err := cc.Ping(ctx, nil); err != nil { +func sayHi(ctx context.Context, ss *ServerSession, v hiParams) ([]*Content, error) { + if err := ss.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } return []*Content{NewTextContent("hi " + v.Name)}, nil @@ -398,7 +399,7 @@ func TestCancellation(t *testing.T) { } } -func TestAddMiddleware(t *testing.T) { +func TestMiddleware(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) @@ -416,26 +417,17 @@ func TestAddMiddleware(t *testing.T) { clientWG.Done() }() - var buf bytes.Buffer - buf.WriteByte('\n') - - // traceCalls creates a middleware function that prints the method before and after each call - // with the given prefix. - traceCalls := func(prefix string) func(ServerMethodHandler) ServerMethodHandler { - return func(d ServerMethodHandler) ServerMethodHandler { - return func(ctx context.Context, ss *ServerSession, method string, params any) (any, error) { - fmt.Fprintf(&buf, "%s >%s\n", prefix, method) - defer fmt.Fprintf(&buf, "%s <%s\n", prefix, method) - return d(ctx, ss, method, params) - } - } - } + var sbuf, cbuf bytes.Buffer + sbuf.WriteByte('\n') + cbuf.WriteByte('\n') // "1" is the outer middleware layer, called first; then "2" is called, and finally // the default dispatcher. 
- s.AddMiddleware(traceCalls("1"), traceCalls("2")) + s.AddMiddleware(traceCalls[ServerSession](&sbuf, "1"), traceCalls[ServerSession](&sbuf, "2")) c := NewClient("testClient", "v1.0.0", nil) + c.AddMiddleware(traceCalls[ClientSession](&cbuf, "1"), traceCalls[ClientSession](&cbuf, "2")) + cs, err := c.Connect(ctx, ct) if err != nil { t.Fatal(err) @@ -453,9 +445,33 @@ func TestAddMiddleware(t *testing.T) { 2 roots/list +2 >roots/list +2 %s\n", prefix, method) + defer fmt.Fprintf(w, "%s <%s\n", prefix, method) + return h(ctx, sess, method, params) + } + } } var falseSchema = &jsonschema.Schema{Not: &jsonschema.Schema{}} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 79f4fe0a04f..9ef1f4fc552 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -6,7 +6,6 @@ package mcp import ( "context" - "encoding/json" "fmt" "iter" "net/url" @@ -27,12 +26,12 @@ type Server struct { version string opts ServerOptions - mu sync.Mutex - prompts *featureSet[*ServerPrompt] - tools *featureSet[*ServerTool] - resources *featureSet[*ServerResource] - sessions []*ServerSession - methodHandler ServerMethodHandler + mu sync.Mutex + prompts *featureSet[*ServerPrompt] + tools *featureSet[*ServerTool] + resources *featureSet[*ServerResource] + sessions []*ServerSession + methodHandler_ MethodHandler[ServerSession] } // ServerOptions is used to configure behavior of the server. 
@@ -52,13 +51,13 @@ func NewServer(name, version string, opts *ServerOptions) *Server { opts = new(ServerOptions) } return &Server{ - name: name, - version: version, - opts: *opts, - prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), - tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), - resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), - methodHandler: defaulMethodHandler, + name: name, + version: version, + opts: *opts, + prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), + tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), + resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), + methodHandler_: defaultMethodHandler[ServerSession], } } @@ -321,7 +320,7 @@ type ServerSession struct { // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, _ *PingParams) error { - return call(ctx, ss.conn, "ping", nil, nil) + return call(ctx, ss.conn, "ping", (*PingParams)(nil), nil) } // ListRoots lists the client roots. @@ -329,67 +328,20 @@ func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) return standardCall[ListRootsResult](ctx, ss.conn, "roots/list", params) } -// A ServerMethodHandler handles MCP messages from client to server. -// The params argument is an XXXParams struct pointer, such as *GetPromptParams. -// For methods, a MethodHandler must return either -// an XXResult struct pointer and a nil error, or -// nil with a non-nil error. -// For notifications, a MethodHandler must return nil, nil. -type ServerMethodHandler func(ctx context.Context, _ *ServerSession, method string, params any) (result any, err error) - // AddMiddleware wraps the server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one // is executed first. 
// -// For example, AddMiddleware(m1, m2, m3) augments the server handler as +// For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). -func (s *Server) AddMiddleware(middleware ...func(ServerMethodHandler) ServerMethodHandler) { +func (s *Server) AddMiddleware(middleware ...Middleware[ServerSession]) { s.mu.Lock() defer s.mu.Unlock() - for _, m := range slices.Backward(middleware) { - s.methodHandler = m(s.methodHandler) - } -} - -// defaulMethodHandler is the initial method handler installed on the server. -func defaulMethodHandler(ctx context.Context, ss *ServerSession, method string, params any) (any, error) { - info, ok := methodInfos[method] - assert(ok, "called with unknown method") - return info.handleMethod(ctx, ss, method, params) -} - -// methodInfo is information about invoking a method. -type methodInfo struct { - // unmarshal params from the wire into an XXXParams struct - unmarshalParams func(json.RawMessage) (any, error) - // run the code for the method - handleMethod ServerMethodHandler -} - -// The following definitions support converting from typed to untyped method handlers. -// Throughout, P is the type parameter for params, and R is the one for result. - -// A typedMethodHandler is like a MethodHandler, but with type information. -type typedMethodHandler[P, R any] func(context.Context, *ServerSession, P) (R, error) - -// newMethodInfo creates a methodInfo from a typedMethodHandler. -func newMethodInfo[P, R any](d typedMethodHandler[P, R]) methodInfo { - return methodInfo{ - unmarshalParams: func(m json.RawMessage) (any, error) { - var p P - if err := json.Unmarshal(m, &p); err != nil { - return nil, err - } - return p, nil - }, - handleMethod: func(ctx context.Context, ss *ServerSession, _ string, params any) (any, error) { - return d(ctx, ss, params.(P)) - }, - } + addMiddleware(&s.methodHandler_, middleware) } -// methodInfos maps from the RPC method name to methodInfos. 
-var methodInfos = map[string]methodInfo{ +// serverMethodInfos maps from the RPC method name to serverMethodInfos. +var serverMethodInfos = map[string]methodInfo[ServerSession]{ "initialize": newMethodInfo(sessionMethod((*ServerSession).initialize)), "ping": newMethodInfo(sessionMethod((*ServerSession).ping)), "prompts/list": newMethodInfo(serverMethod((*Server).listPrompts)), @@ -401,18 +353,18 @@ var methodInfos = map[string]methodInfo{ // TODO: notifications } -// serverMethod is glue for creating a typedMethodHandler from a method on Server. -func serverMethod[P, R any](f func(*Server, context.Context, *ServerSession, P) (R, error)) typedMethodHandler[P, R] { - return func(ctx context.Context, ss *ServerSession, p P) (R, error) { - return f(ss.server, ctx, ss, p) - } +// *ServerSession implements the session interface. +// See toSession for why this interface seems to be necessary. +var _ session[ServerSession] = (*ServerSession)(nil) + +func (ss *ServerSession) methodInfos() map[string]methodInfo[ServerSession] { + return serverMethodInfos } -// sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. -func sessionMethod[P, R any](f func(*ServerSession, context.Context, P) (R, error)) typedMethodHandler[P, R] { - return func(ctx context.Context, ss *ServerSession, p P) (R, error) { - return f(ss, ctx, p) - } +func (ss *ServerSession) methodHandler() MethodHandler[ServerSession] { + ss.server.mu.Lock() + defer ss.server.mu.Unlock() + return ss.server.methodHandler_ } // handle invokes the method described by the given JSON RPC request. 
@@ -430,27 +382,10 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc2.Request) (any return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } - - // TODO: embed the incoming request ID in the client context (or, more likely, + // TODO(rfindley): embed the incoming request ID in the client context (or, more likely, // a wrapper around it), so that we can correlate responses and notifications // to the handler; this is required for the new session-based transport. - info, ok := methodInfos[req.Method] - if !ok { - return nil, jsonrpc2.ErrNotHandled - } - params, err := info.unmarshalParams(req.Params) - if err != nil { - return nil, fmt.Errorf("ServerSession:handle %q: %w", req.Method, err) - } - ss.server.mu.Lock() - d := ss.server.methodHandler - ss.server.mu.Unlock() - // d might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - res, err := d(ctx, ss, req.Method, params) - if err != nil { - return nil, err - } - return res, nil + return handleRequest(ctx, req, ss) } func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { @@ -458,8 +393,8 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam ss.initializeParams = params ss.mu.Unlock() - // Mark the connection as initialized when this method exits. TODO: - // Technically, the server should not be considered initialized until it has + // Mark the connection as initialized when this method exits. + // TODO: Technically, the server should not be considered initialized until it has // *responded*, but we don't have adequate visibility into the jsonrpc2 // connection to implement that easily. In any case, once we've initialized // here, we can handle requests. 
@@ -488,7 +423,7 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam }, nil } -func (ss *ServerSession) ping(context.Context, struct{}) (struct{}, error) { +func (ss *ServerSession) ping(context.Context, *PingParams) (struct{}, error) { return struct{}{}, nil } diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go new file mode 100644 index 00000000000..cbb40b006c5 --- /dev/null +++ b/internal/mcp/shared.go @@ -0,0 +1,136 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains code shared between client and server, including +// method handler and middleware definitions. +// TODO: much of this is here so that we can factor out commonalities using +// generics. Perhaps it can be simplified with reflection. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "slices" + + jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" +) + +// A MethodHandler handles MCP messages. +// The params argument is an XXXParams struct pointer, such as *GetPromptParams. +// For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or +// nil with a non-nil error. +// For notifications, a MethodHandler must return nil, nil. +type MethodHandler[S ClientSession | ServerSession] func( + ctx context.Context, _ *S, method string, params any) (result any, err error) + +// Middleware is a function from MethodHandlers to MethodHandlers. +type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHandler[S] + +// addMiddleware wraps the handler in the middleware functions. +func addMiddleware[S ClientSession | ServerSession](handlerp *MethodHandler[S], middleware []Middleware[S]) { + for _, m := range slices.Backward(middleware) { + *handlerp = m(*handlerp) + } +} + +// session has methods common to both ClientSession and ServerSession. 
+type session[S ClientSession | ServerSession] interface { + methodHandler() MethodHandler[S] + methodInfos() map[string]methodInfo[S] +} + +// toSession[S] converts its argument to a session[S]. +// Note that since S is constrained to ClientSession | ServerSession, and pointers to those +// types both implement session[S] already, this should be a no-op. +// That it is not, is due (I believe) to a deficency in generics, possibly related to core types. +// TODO(jba): revisit in Go 1.26; perhaps the change in spec due to the removal of core types +// will have resulted by then in a more generous implementation. +func toSession[S ClientSession | ServerSession](sess *S) session[S] { + return any(sess).(session[S]) +} + +// defaultMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. +func defaultMethodHandler[S ClientSession | ServerSession](ctx context.Context, sess *S, method string, params any) (any, error) { + info, ok := toSession(sess).methodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. + return nil, jsonrpc2.ErrNotHandled + } + return info.handleMethod(ctx, sess, method, params) +} + +func handleRequest[S ClientSession | ServerSession](ctx context.Context, req *jsonrpc2.Request, sess *S) (any, error) { + info, ok := toSession(sess).methodInfos()[req.Method] + if !ok { + return nil, jsonrpc2.ErrNotHandled + } + params, err := info.unmarshalParams(req.Params) + if err != nil { + return nil, fmt.Errorf("handleRequest %q: %w", req.Method, err) + } + + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. + mh := toSession(sess).methodHandler() + res, err := mh(ctx, sess, req.Method, params) + if err != nil { + return nil, err + } + return res, nil +} + +// methodInfo is information about invoking a method. 
+type methodInfo[TSession ClientSession | ServerSession] struct { + // unmarshal params from the wire into an XXXParams struct + unmarshalParams func(json.RawMessage) (any, error) + // run the code for the method + handleMethod MethodHandler[TSession] +} + +// The following definitions support converting from typed to untyped method handlers. +// Type parameter meanings: +// - S: sessions +// - P: params +// - R: results + +// A typedMethodHandler is like a MethodHandler, but with type information. +type typedMethodHandler[S, P, R any] func(context.Context, *S, P) (R, error) + +// newMethodInfo creates a methodInfo from a typedMethodHandler. +func newMethodInfo[S ClientSession | ServerSession, P, R any](d typedMethodHandler[S, P, R]) methodInfo[S] { + return methodInfo[S]{ + unmarshalParams: func(m json.RawMessage) (any, error) { + var p P + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } + return p, nil + }, + handleMethod: func(ctx context.Context, ss *S, _ string, params any) (any, error) { + return d(ctx, ss, params.(P)) + }, + } +} + +// serverMethod is glue for creating a typedMethodHandler from a method on Server. +func serverMethod[P, R any](f func(*Server, context.Context, *ServerSession, P) (R, error)) typedMethodHandler[ServerSession, P, R] { + return func(ctx context.Context, ss *ServerSession, p P) (R, error) { + return f(ss.server, ctx, ss, p) + } +} + +// clientMethod is glue for creating a typedMethodHandler from a method on Server. +func clientMethod[P, R any](f func(*Client, context.Context, *ClientSession, P) (R, error)) typedMethodHandler[ClientSession, P, R] { + return func(ctx context.Context, cs *ClientSession, p P) (R, error) { + return f(cs.client, ctx, cs, p) + } +} + +// sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. 
+func sessionMethod[S ClientSession | ServerSession, P, R any](f func(*S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { + return func(ctx context.Context, sess *S, p P) (R, error) { + return f(sess, ctx, p) + } +} From 86158bdc5721543ffce92b6f9920d4dbcb4a3db8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 16 May 2025 08:53:48 -0400 Subject: [PATCH 098/196] internal/mcp: sampling Implement sampling (the CreateMessage method). Change-Id: Ie44a671dd62831122b25cd7386b5bfcdece7ff60 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673176 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 26 ++++++++-- internal/mcp/generate.go | 20 +++++++- internal/mcp/mcp_test.go | 42 ++++++++++++---- internal/mcp/protocol.go | 103 ++++++++++++++++++++++++++++++++++++++- internal/mcp/resource.go | 10 +--- internal/mcp/server.go | 5 ++ internal/mcp/shared.go | 13 +++++ 7 files changed, 192 insertions(+), 27 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index f19e779ef20..ab631a0f7c9 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -46,7 +46,11 @@ func NewClient(name, version string, opts *ClientOptions) *Client { } // ClientOptions configures the behavior of the client. -type ClientOptions struct{} +type ClientOptions struct { + // Handler for sampling. + // Called when a server calls CreateMessage. + CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) +} // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. 
@@ -83,8 +87,13 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e if err != nil { return nil, err } + caps := &ClientCapabilities{} + if c.opts.CreateMessageHandler != nil { + caps.Sampling = &SamplingCapabilities{} + } params := &InitializeParams{ - ClientInfo: &implementation{Name: c.name, Version: c.version}, + ClientInfo: &implementation{Name: c.name, Version: c.version}, + Capabilities: caps, } if err := call(ctx, cs.conn, "initialize", params, &cs.initializeResult); err != nil { _ = cs.Close() @@ -150,6 +159,14 @@ func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsPara }, nil } +func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *CreateMessageParams) (*CreateMessageResult, error) { + if c.opts.CreateMessageHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, &jsonrpc2.WireError{Code: CodeUnsupportedMethod, Message: "client does not support CreateMessage"} + } + return c.opts.CreateMessageHandler(ctx, cs, params) +} + // AddMiddleware wraps the client's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one // is executed first. @@ -164,8 +181,9 @@ func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) { // clientMethodInfos maps from the RPC method name to serverMethodInfos. 
var clientMethodInfos = map[string]methodInfo[ClientSession]{ - "ping": newMethodInfo(sessionMethod((*ClientSession).ping)), - "roots/list": newMethodInfo(clientMethod((*Client).listRoots)), + "ping": newMethodInfo(sessionMethod((*ClientSession).ping)), + "roots/list": newMethodInfo(clientMethod((*Client).listRoots)), + "sampling/createMessage": newMethodInfo(clientMethod((*Client).createMessage)), // TODO: notifications } diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 2c549cde126..35dea980777 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -68,7 +68,14 @@ var declarations = config{ Name: "-", Fields: config{"Params": {Name: "CancelledParams"}}, }, - "ClientCapabilities": {}, + "ClientCapabilities": { + Fields: config{"Sampling": {Name: "SamplingCapabilities"}}, + }, + "CreateMessageRequest": { + Name: "-", + Fields: config{"Params": {Name: "CreateMessageParams"}}, + }, + "CreateMessageResult": {}, "GetPromptRequest": { Name: "-", Fields: config{"Params": {Name: "GetPromptParams"}}, @@ -103,7 +110,9 @@ var declarations = config{ Name: "-", Fields: config{"Params": {Name: "ListToolsParams"}}, }, - "ListToolsResult": {}, + "ListToolsResult": {}, + "ModelHint": {}, + "ModelPreferences": {}, "PingRequest": { Name: "-", Fields: config{"Params": {Name: "PingParams"}}, @@ -124,6 +133,8 @@ var declarations = config{ "Role": {}, "Root": {}, + "SamplingCapabilities": {Substitute: "struct{}"}, + "SamplingMessage": {}, "ServerCapabilities": { Name: "serverCapabilities", Fields: config{ @@ -350,6 +361,11 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma fieldTypeSchema = rs } needPointer := isStruct(fieldTypeSchema) + // Special case: there are no sampling capabilities defined, but + // we want it to be a struct for future expansion. 
+ if !needPointer && name == "sampling" { + needPointer = true + } if config != nil && config.Fields[export] != nil { r := config.Fields[export] if r.Substitute != "" { diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index ce66e9806eb..98ced31bdc6 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -85,7 +85,12 @@ func TestEndToEnd(t *testing.T) { clientWG.Done() }() - c := NewClient("testClient", "v1.0.0", nil) + opts := &ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel"}, nil + }, + } + c := NewClient("testClient", "v1.0.0", opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) if err != nil { t.Fatal(err) @@ -233,8 +238,7 @@ func TestEndToEnd(t *testing.T) { } { rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) if err != nil { - var werr *jsonrpc2.WireError - if errors.As(err, &werr) && werr.Code == codeResourceNotFound { + if code := errorCode(err); code == CodeResourceNotFound { if tt.mimeType != "" { t.Errorf("%s: not found but expected it to be", tt.uri) } @@ -252,13 +256,7 @@ func TestEndToEnd(t *testing.T) { } }) t.Run("roots", func(t *testing.T) { - // Take the server's first ServerSession. - var sc *ServerSession - for sc = range s.Sessions() { - break - } - - rootRes, err := sc.ListRoots(ctx, &ListRootsParams{}) + rootRes, err := ss.ListRoots(ctx, &ListRootsParams{}) if err != nil { t.Fatal(err) } @@ -268,6 +266,16 @@ func TestEndToEnd(t *testing.T) { t.Errorf("roots/list mismatch (-want +got):\n%s", diff) } }) + t.Run("sampling", func(t *testing.T) { + // TODO: test that a client that doesn't have the handler returns CodeUnsupportedMethod. + res, err := ss.CreateMessage(ctx, &CreateMessageParams{}) + if err != nil { + t.Fatal(err) + } + if g, w := res.Model, "aModel"; g != w { + t.Errorf("got %q, want %q", g, w) + } + }) // Disconnect. 
cs.Close() @@ -280,6 +288,20 @@ func TestEndToEnd(t *testing.T) { } } +// errorCode returns the code associated with err. +// If err is nil, it returns 0. +// If there is no code, it returns -1. +func errorCode(err error) int64 { + if err == nil { + return 0 + } + var werr *jsonrpc2.WireError + if errors.As(err, &werr) { + return werr.Code + } + return -1 +} + // basicConnection returns a new basic client-server connection configured with // the provided tools. // diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 7b68e067edb..66f898c2aa3 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -78,8 +78,45 @@ type ClientCapabilities struct { ListChanged bool `json:"listChanged,omitempty"` } `json:"roots,omitempty"` // Present if the client supports sampling from an LLM. - Sampling struct { - } `json:"sampling,omitempty"` + Sampling *SamplingCapabilities `json:"sampling,omitempty"` +} + +type CreateMessageParams struct { + // A request to include context from one or more MCP servers (including the + // caller), to be attached to the prompt. The client MAY ignore this request. + IncludeContext string `json:"includeContext,omitempty"` + // The maximum number of tokens to sample, as requested by the server. The + // client MAY choose to sample fewer tokens than requested. + MaxTokens int64 `json:"maxTokens"` + Messages []*SamplingMessage `json:"messages"` + // Optional metadata to pass through to the LLM provider. The format of this + // metadata is provider-specific. + Metadata struct { + } `json:"metadata,omitempty"` + // The server's preferences for which model to select. The client MAY ignore + // these preferences. + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + // An optional system prompt the server wants to use for sampling. The client + // MAY modify or omit this prompt. 
+ SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +// The client's response to a sampling/create_message request from the server. +// The client should inform the user before returning the sampled message, to +// allow them to inspect the response (human in the loop) and decide whether to +// allow the server to see it. +type CreateMessageResult struct { + // This result property is reserved by the protocol to allow clients and servers + // to attach additional metadata to their responses. + Meta map[string]json.RawMessage `json:"_meta,omitempty"` + Content *Content `json:"content"` + // The name of the model that generated the message. + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` } type GetPromptParams struct { @@ -204,6 +241,58 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +// Hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client SHOULD treat this as a substring of a model name; for example: - + // `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` - `sonnet` + // should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. - + // `claude` should match any Claude model + // + // The client MAY also map the string to a different provider's model name or a + // different model family, as long as it fills a similar niche; for example: - + // `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +// The server's preferences for model selection, requested of the client during +// sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" model is +// rarely straightforward. 
Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client MAY ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important factor. + CostPriority float64 `json:"costPriority,omitempty"` + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client MUST evaluate them in order (such + // that the first match is taken). + // + // The client SHOULD prioritize these hints over the numeric priorities, but MAY + // still use the priorities to select from ambiguous matches. + Hints []*ModelHint `json:"hints,omitempty"` + // How much to prioritize intelligence and capabilities when selecting a model. + // A value of 0 means intelligence is not important, while a value of 1 means + // intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` +} + type PingParams struct { Meta struct { // If specified, the caller is requesting out-of-band progress notifications for @@ -297,6 +386,16 @@ type Root struct { URI string `json:"uri"` } +// Present if the client supports sampling from an LLM. 
+type SamplingCapabilities struct { +} + +// Describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Content *Content `json:"content"` + Role Role `json:"role"` +} + // Definition for a tool the client can call. type Tool struct { // Optional additional tool information. diff --git a/internal/mcp/resource.go b/internal/mcp/resource.go index 2d4bea75e3e..e09abe168c0 100644 --- a/internal/mcp/resource.go +++ b/internal/mcp/resource.go @@ -32,20 +32,12 @@ type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) // not be found. func ResourceNotFoundError(uri string) error { return &jsonrpc2.WireError{ - Code: codeResourceNotFound, + Code: CodeResourceNotFound, Message: "Resource not found", Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), } } -// The error code to return when a resource isn't found. -// See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling -// However, the code they chose in in the wrong space -// (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). -// so we pick a different one, arbirarily for now (until they fix it). -// The immediate problem is that jsonprc2 defines -32002 as "server closing". -const codeResourceNotFound = -31002 - // readFileResource reads from the filesystem at a URI relative to dirFilepath, respecting // the roots. // dirFilepath and rootFilepaths are absolute filesystem paths. diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 9ef1f4fc552..7b22179d7d2 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -328,6 +328,11 @@ func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) return standardCall[ListRootsResult](ctx, ss.conn, "roots/list", params) } +// CreateMessage sends a sampling request to the client. 
+func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + return standardCall[CreateMessageResult](ctx, ss.conn, "sampling/createMessage", params) +} + // AddMiddleware wraps the server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one // is executed first. diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index cbb40b006c5..e82ce20cab5 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -134,3 +134,16 @@ func sessionMethod[S ClientSession | ServerSession, P, R any](f func(*S, context return f(sess, ctx, p) } } + +// Error codes +const ( + // The error code to return when a resource isn't found. + // See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling + // However, the code they chose in in the wrong space + // (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). + // so we pick a different one, arbirarily for now (until they fix it). + // The immediate problem is that jsonprc2 defines -32002 as "server closing". + CodeResourceNotFound = -31002 + // The error code if the method exists and was called properly, but the peer does not support it. + CodeUnsupportedMethod = -31001 +) From 6202e585494d79a991f1a7a44b09498accdf5c16 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 17 May 2025 13:02:49 -0400 Subject: [PATCH 099/196] internal/mcp: generate method names Augment the generator to create constants for all method names. 
Change-Id: I858e516b1afec0aa3cf3b9852cbdfc16aec4c553 Reviewed-on: https://go-review.googlesource.com/c/tools/+/673177 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla --- internal/mcp/client.go | 22 +++++++++++----------- internal/mcp/generate.go | 18 ++++++++++++++++++ internal/mcp/protocol.go | 27 +++++++++++++++++++++++++++ internal/mcp/server.go | 20 ++++++++++---------- 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index ab631a0f7c9..23cd53726a8 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -99,7 +99,7 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e _ = cs.Close() return nil, err } - if err := cs.conn.Notify(ctx, "notifications/initialized", &InitializedParams{}); err != nil { + if err := cs.conn.Notify(ctx, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() return nil, err } @@ -181,9 +181,9 @@ func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) { // clientMethodInfos maps from the RPC method name to serverMethodInfos. var clientMethodInfos = map[string]methodInfo[ClientSession]{ - "ping": newMethodInfo(sessionMethod((*ClientSession).ping)), - "roots/list": newMethodInfo(clientMethod((*Client).listRoots)), - "sampling/createMessage": newMethodInfo(clientMethod((*Client).createMessage)), + methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), + methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), + methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), // TODO: notifications } @@ -209,22 +209,22 @@ func (c *ClientSession) ping(ct context.Context, params *PingParams) (struct{}, // Ping makes an MCP "ping" request to the server. 
func (c *ClientSession) Ping(ctx context.Context, params *PingParams) error { - return call(ctx, c.conn, "ping", params, nil) + return call(ctx, c.conn, methodPing, params, nil) } // ListPrompts lists prompts that are currently available on the server. func (c *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { - return standardCall[ListPromptsResult](ctx, c.conn, "prompts/list", params) + return standardCall[ListPromptsResult](ctx, c.conn, methodListPrompts, params) } // GetPrompt gets a prompt from the server. func (c *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { - return standardCall[GetPromptResult](ctx, c.conn, "prompts/get", params) + return standardCall[GetPromptResult](ctx, c.conn, methodGetPrompt, params) } // ListTools lists tools that are currently available on the server. func (c *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return standardCall[ListToolsResult](ctx, c.conn, "tools/list", params) + return standardCall[ListToolsResult](ctx, c.conn, methodListTools, params) } // CallTool calls the tool with the given name and arguments. @@ -244,7 +244,7 @@ func (c *ClientSession) CallTool(ctx context.Context, name string, args map[stri Name: name, Arguments: json.RawMessage(data), } - return standardCall[CallToolResult](ctx, c.conn, "tools/call", params) + return standardCall[CallToolResult](ctx, c.conn, methodCallTool, params) } // NOTE: the following struct should consist of all fields of callToolParams except name and arguments. @@ -256,12 +256,12 @@ type CallToolOptions struct { // ListResources lists the resources that are currently available on the server. 
func (c *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { - return standardCall[ListResourcesResult](ctx, c.conn, "resources/list", params) + return standardCall[ListResourcesResult](ctx, c.conn, methodListResources, params) } // ReadResource ask the server to read a resource and return its contents. func (c *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { - return standardCall[ReadResourceResult](ctx, c.conn, "resources/read", params) + return standardCall[ReadResourceResult](ctx, c.conn, methodReadResource, params) } func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 35dea980777..3a67bd6b8db 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -202,6 +202,24 @@ import ( fmt.Fprintln(buf) fmt.Fprint(buf, b.String()) } + // Write out method names. + fmt.Fprintln(buf, `const (`) + for name, s := range schema.Definitions { + prefix := "method" + method, found := strings.CutSuffix(name, "Request") + if !found { + prefix = "notification" + method, found = strings.CutSuffix(name, "Notification") + } + if found { + if ms, ok := s.Properties["method"]; ok { + if c := ms.Const; c != nil { + fmt.Fprintf(buf, "%s%s = %q\n", prefix, method, *c) + } + } + } + } + fmt.Fprintln(buf, `)`) formatted, err := format.Source(buf.Bytes()) if err != nil { diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 66f898c2aa3..7b3ca5fc8b9 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -494,3 +494,30 @@ type toolCapabilities struct { // Whether this server supports notifications for changes to the tool list. 
ListChanged bool `json:"listChanged,omitempty"` } + +const ( + notificationCancelled = "notifications/cancelled" + methodInitialize = "initialize" + notificationProgress = "notifications/progress" + methodSetLevel = "logging/setLevel" + methodCreateMessage = "sampling/createMessage" + notificationResourceListChanged = "notifications/resources/list_changed" + notificationInitialized = "notifications/initialized" + methodUnsubscribe = "resources/unsubscribe" + notificationLoggingMessage = "notifications/message" + methodSubscribe = "resources/subscribe" + methodComplete = "completion/complete" + methodCallTool = "tools/call" + notificationPromptListChanged = "notifications/prompts/list_changed" + methodReadResource = "resources/read" + methodListResourceTemplates = "resources/templates/list" + methodListRoots = "roots/list" + notificationToolListChanged = "notifications/tools/list_changed" + methodGetPrompt = "prompts/get" + methodListPrompts = "prompts/list" + methodPing = "ping" + notificationRootsListChanged = "notifications/roots/list_changed" + methodListTools = "tools/list" + methodListResources = "resources/list" + notificationResourceUpdated = "notifications/resources/updated" +) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 7b22179d7d2..cb94158cdb6 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -325,12 +325,12 @@ func (ss *ServerSession) Ping(ctx context.Context, _ *PingParams) error { // ListRoots lists the client roots. func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { - return standardCall[ListRootsResult](ctx, ss.conn, "roots/list", params) + return standardCall[ListRootsResult](ctx, ss.conn, methodListRoots, params) } // CreateMessage sends a sampling request to the client. 
func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { - return standardCall[CreateMessageResult](ctx, ss.conn, "sampling/createMessage", params) + return standardCall[CreateMessageResult](ctx, ss.conn, methodCreateMessage, params) } // AddMiddleware wraps the server's current method handler using the provided @@ -347,14 +347,14 @@ func (s *Server) AddMiddleware(middleware ...Middleware[ServerSession]) { // serverMethodInfos maps from the RPC method name to serverMethodInfos. var serverMethodInfos = map[string]methodInfo[ServerSession]{ - "initialize": newMethodInfo(sessionMethod((*ServerSession).initialize)), - "ping": newMethodInfo(sessionMethod((*ServerSession).ping)), - "prompts/list": newMethodInfo(serverMethod((*Server).listPrompts)), - "prompts/get": newMethodInfo(serverMethod((*Server).getPrompt)), - "tools/list": newMethodInfo(serverMethod((*Server).listTools)), - "tools/call": newMethodInfo(serverMethod((*Server).callTool)), - "resources/list": newMethodInfo(serverMethod((*Server).listResources)), - "resources/read": newMethodInfo(serverMethod((*Server).readResource)), + methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)), + methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)), + methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)), + methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt)), + methodListTools: newMethodInfo(serverMethod((*Server).listTools)), + methodCallTool: newMethodInfo(serverMethod((*Server).callTool)), + methodListResources: newMethodInfo(serverMethod((*Server).listResources)), + methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), // TODO: notifications } From 77de774c241337a034fad919afa8417c17667471 Mon Sep 17 00:00:00 2001 From: Madeline Kalil Date: Mon, 28 Apr 2025 16:40:46 -0400 Subject: [PATCH 100/196] gopls/internal/golang: modify extract behavior for error handling returns Check 
if all return statements in the extracted code are of the form if err != nil { return ..., err }. If this is the case, then we bypass the logic that creates the "shouldReturn" clause, instead adding an error check after the call to the new function. Fixes golang/go#66289 Change-Id: I88bcc80772d442bc264630dabd8d62916e5a7cb5 Reviewed-on: https://go-review.googlesource.com/c/tools/+/668676 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/golang/extract.go | 168 +++++++++++------- .../testdata/codeaction/extract_control.txt | 8 +- .../codeaction/extract_return_err.txt | 140 +++++++++++++++ .../codeaction/functionextraction.txt | 16 +- .../functionextraction_issue66289.txt | 18 +- 5 files changed, 267 insertions(+), 83 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/codeaction/extract_return_err.txt diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 70a84f6159d..cc0ee536b1c 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -27,7 +27,6 @@ import ( goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" - "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/typesinternal" ) @@ -628,31 +627,75 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to // A return statement is non-nested if its parent node is equal to the parent node // of the first node in the selection. These cases must be handled separately because // non-nested return statements are guaranteed to execute. 
- var retStmts []*ast.ReturnStmt var hasNonNestedReturn bool - startParent := findParent(outer, node) - ast.Inspect(outer, func(n ast.Node) bool { - if n == nil { - return false - } - if n.Pos() < start || n.End() > end { - return n.Pos() <= end + curStart, ok := pgf.Cursor.FindNode(node) + if !ok { + return nil, nil, bug.Errorf("cannot find Cursor for start Node") + } + curOuter, ok := pgf.Cursor.FindNode(outer) + if !ok { + return nil, nil, bug.Errorf("cannot find Cursor for start Node") + } + // Determine whether all return statements in the selection are + // error-handling return statements. They must be of the form: + // if err != nil { + // return ..., err + // } + // If all return statements in the extracted block have a non-nil error, we + // can replace the "shouldReturn" check with an error check to produce a + // more concise output. + allReturnsFinalErr := true // all ReturnStmts have final 'err' expression + hasReturn := false // selection contains a ReturnStmt + filter := []ast.Node{(*ast.ReturnStmt)(nil), (*ast.FuncLit)(nil)} + curOuter.Inspect(filter, func(cur inspector.Cursor) (descend bool) { + if funcLit, ok := cur.Node().(*ast.FuncLit); ok { + // Exclude return statements in function literals because they don't affect the refactor. + // Keep descending into func lits whose declaration is not included in the extracted block. + return !(start < funcLit.Pos() && funcLit.End() < end) + } + ret := cur.Node().(*ast.ReturnStmt) + if ret.Pos() < start || ret.End() > end { + return false // not part of the extracted block + } + hasReturn = true + + if cur.Parent() == curStart.Parent() { + hasNonNestedReturn = true } - // exclude return statements in function literals because they don't affect the refactor. - if _, ok := n.(*ast.FuncLit); ok { + + if !allReturnsFinalErr { + // Stop the traversal if we have already found a non error-handling return statement. 
return false } - ret, ok := n.(*ast.ReturnStmt) - if !ok { - return true - } - if findParent(outer, n) == startParent { - hasNonNestedReturn = true + // Check if the return statement returns a non-nil error as the last value. + if len(ret.Results) > 0 { + typ := info.TypeOf(ret.Results[len(ret.Results)-1]) + if typ != nil && types.Identical(typ, errorType) { + // Have: return ..., err + // Check for enclosing "if err != nil { return ..., err }". + // In that case, we can lift the error return to the caller. + if ifstmt, ok := cur.Parent().Parent().Node().(*ast.IfStmt); ok { + // Only handle the case where the if statement body contains a single statement. + if body, ok := cur.Parent().Node().(*ast.BlockStmt); ok && len(body.List) <= 1 { + if cond, ok := ifstmt.Cond.(*ast.BinaryExpr); ok { + tx := info.TypeOf(cond.X) + ty := info.TypeOf(cond.Y) + isErr := tx != nil && types.Identical(tx, errorType) + isNil := ty != nil && types.Identical(ty, types.Typ[types.UntypedNil]) + if cond.Op == token.NEQ && isErr && isNil { + // allReturnsErrHandling remains true + return false + } + } + } + } + } } - retStmts = append(retStmts, ret) + allReturnsFinalErr = false return false }) - containsReturnStatement := len(retStmts) > 0 + + allReturnsFinalErr = hasReturn && allReturnsFinalErr // Now that we have determined the correct range for the selection block, // we must determine the signature of the extracted function. We will then replace @@ -754,6 +797,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to // // The second condition below handles the case when // v's block is the FuncDecl.Body itself. 
+ startParent := curStart.Parent().Node() if vscope.Pos() == startParent.Pos() || startParent == outer.Body && vscope == info.Scopes[outer.Type] { canRedefineCount++ @@ -894,13 +938,26 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to var retVars []*returnVariable var ifReturn *ast.IfStmt - if containsReturnStatement { + + // Determine if the extracted block contains any free branch statements, for + // example: "continue label" where "label" is declared outside of the + // extracted block, or continue inside a "for" statement where the for + // statement is declared outside of the extracted block. These will be + // handled below, after adjusting return statements and generating return + // info. + curSel, _ := pgf.Cursor.FindByPos(start, end) // since canExtractFunction succeeded, this will always return a valid cursor + freeBranches := freeBranches(info, curSel, start, end) + + // All return statements in the extracted block are error handling returns, and there are no free control statements. + isErrHandlingReturnsCase := allReturnsFinalErr && len(freeBranches) == 0 + + if hasReturn { if !hasNonNestedReturn { // The selected block contained return statements, so we have to modify the // signature of the extracted function as described above. Adjust all of // the return statements in the extracted function to reflect this change in // signature. - if err := adjustReturnStatements(returnTypes, seenVars, extractedBlock, qual); err != nil { + if err := adjustReturnStatements(returnTypes, seenVars, extractedBlock, qual, isErrHandlingReturnsCase); err != nil { return nil, nil, err } } @@ -908,17 +965,12 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to // statements in the selection. Update the type signature of the extracted // function and construct the if statement that will be inserted in the enclosing // function. 
- retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, end, hasNonNestedReturn) + retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, end, hasNonNestedReturn, isErrHandlingReturnsCase) if err != nil { return nil, nil, err } } - // Determine if the extracted block contains any free branch statements, for - // example: "continue label" where "label" is declared outside of the - // extracted block, or continue inside a "for" statement where the for - // statement is declared outside of the extracted block. - // If the extracted block contains free branch statements, we add another // return value "ctrl" to the extracted function that will be used to // determine the control flow. See the following example, where === denotes @@ -956,9 +1008,6 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to // } // - curSel, _ := pgf.Cursor.FindByPos(start, end) // since canExtractFunction succeeded, this will always return a valid cursor - freeBranches := freeBranches(info, curSel, start, end) - // Generate an unused identifier for the control value. 
ctrlVar, _ := freshName(info, file, start, "ctrl", 0) if len(freeBranches) > 0 { @@ -1079,8 +1128,16 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to return nil, nil, err } if ifReturn != nil { - if err := format.Node(&ifBuf, fset, ifReturn); err != nil { - return nil, nil, err + if isErrHandlingReturnsCase { + errName := retVars[len(retVars)-1] + fmt.Fprintf(&ifBuf, "if %s != nil ", errName.name.String()) + if err := format.Node(&ifBuf, fset, ifReturn.Body); err != nil { + return nil, nil, err + } + } else { + if err := format.Node(&ifBuf, fset, ifReturn); err != nil { + return nil, nil, err + } } } @@ -1307,20 +1364,6 @@ func isGoWhiteSpace(b byte) bool { return uint64(scanner.GoWhitespace)&(1< 2 { -+ return false, 0, 1 ++ return 0, false, 1 + } + if test == 10 { + return true, 1, 0 + } -+ return false, 0, 0 ++ return 0, false, 0 +} -- @freeControl8/freecontrol.go -- @@ -68,5 +68,3 @@ diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_return_err.txt b/gopls/internal/test/marker/testdata/codeaction/extract_return_err.txt new file mode 100644 index 00000000000..ba5b3e85f9e --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_return_err.txt @@ -0,0 +1,140 @@ +This test verifies various behaviors of function extraction when every return statement in the extracted block is an error handling return. 
+ +-- go.mod -- +module mod.test/extract + +go 1.18 + +-- errhandling.go -- +package err_handling +import ( + "encoding/json" + "fmt" +) + +//@codeaction(errHandlingBlk1, "refactor.extract.function", edit=err_handling_1) +//@codeaction(errHandlingBlk2, "refactor.extract.function", edit=err_handling_2) +//@codeaction(errHandlingBlk3, "refactor.extract.function", edit=err_handling_3) +//@codeaction(errHandlingBlk4, "refactor.extract.function", edit=err_handling_4) +//@codeaction(errHandlingBlk5, "refactor.extract.function", edit=err_handling_5) + +func Func() error { + a, err := json.Marshal(0) //@loc(errHandlingBlk1, re`(?s)a.*err1....`) + if err != nil { + return fmt.Errorf("1: %w", err) + } + b, err1 := json.Marshal(0) + if err1 != nil { + return fmt.Errorf("2: %w", err1) + } + fmt.Println(string(a), string(b)) + return nil +} + +func FuncReturnsInt() (int, error) { + a, err := json.Marshal(0) //@loc(errHandlingBlk2, re`(?s)a.*err2....`) + if err != nil { + return 0, fmt.Errorf("1: %w", err) + } + b, err2 := json.Marshal(0) + if err2 != nil { + return 1, fmt.Errorf("2: %w", err2) + } + fmt.Println(string(a), string(b)) + return 3, nil +} + +func FuncHasNilReturns() error { + if _, err := json.Marshal(0); err != nil { //@loc(errHandlingBlk3, re`(?s)if.*return.nil`) + return err + } + if _, err := json.Marshal(1); err != nil { + return err + } + return nil +} + +func FuncHasOtherReturns() ([]byte, error) { + if a, err := json.Marshal(0); err != nil { //@loc(errHandlingBlk4, re`(?s)if.*Marshal.1.`) + return a, err + } + return json.Marshal(1) +} + +func FuncErrNameAlreadyExists(err error) ([]byte, error) { + if a, err := json.Marshal(0); err != nil { //@loc(errHandlingBlk5, re`(?s)if.*a,.err...`) + return a, err + } + if a, err := json.Marshal(3); err != nil { + return a, err + } + return []byte{}, nil +} + +-- @err_handling_1/errhandling.go -- +@@ -14 +14,9 @@ ++ a, b, err := newFunction() ++ if err != nil { ++ return err ++ } ++ fmt.Println(string(a), string(b)) ++ 
return nil ++} ++ ++func newFunction() ([]byte, []byte, error) { +@@ -16 +25 @@ +- return fmt.Errorf("1: %w", err) ++ return nil, nil, fmt.Errorf("1: %w", err) +@@ -20 +29 @@ +- return fmt.Errorf("2: %w", err1) ++ return nil, nil, fmt.Errorf("2: %w", err1) +@@ -22,2 +31 @@ +- fmt.Println(string(a), string(b)) +- return nil ++ return a, b, nil +-- @err_handling_2/errhandling.go -- +@@ -27 +27,9 @@ ++ a, b, i, err := newFunction() ++ if err != nil { ++ return i, err ++ } ++ fmt.Println(string(a), string(b)) ++ return 3, nil ++} ++ ++func newFunction() ([]byte, []byte, int, error) { +@@ -29 +38 @@ +- return 0, fmt.Errorf("1: %w", err) ++ return nil, nil, 0, fmt.Errorf("1: %w", err) +@@ -33 +42 @@ +- return 1, fmt.Errorf("2: %w", err2) ++ return nil, nil, 1, fmt.Errorf("2: %w", err2) +@@ -35,2 +44 @@ +- fmt.Println(string(a), string(b)) +- return 3, nil ++ return a, b, 0, nil +-- @err_handling_3/errhandling.go -- +@@ -40 +40,4 @@ ++ return newFunction() ++} ++ ++func newFunction() error { +-- @err_handling_4/errhandling.go -- +@@ -50 +50,4 @@ ++ return newFunction() ++} ++ ++func newFunction() ([]byte, error) { +-- @err_handling_5/errhandling.go -- +@@ -57 +57,8 @@ ++ result, err1 := newFunction() ++ if err1 != nil { ++ return result, err1 ++ } ++ return []byte{}, nil ++} ++ ++func newFunction() ([]byte, error) { +@@ -63 +71 @@ +- return []byte{}, nil ++ return nil, nil diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt index 73276cbd03b..3006f34f3eb 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt @@ -53,7 +53,7 @@ package extract func _() bool { x := 1 - shouldReturn, b := newFunction(x) + b, shouldReturn := newFunction(x) if shouldReturn { return b } //@loc(ifend, "}") @@ -119,14 +119,14 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - 
z, shouldReturn, i, s, err := newFunction(y, x) + z, i, s, err, shouldReturn := newFunction(y, x) if shouldReturn { return i, s, err } //@loc(rcEnd, "}") return x, z, nil } -func newFunction(y string, x int) (string, bool, int, string, error) { +func newFunction(y string, x int) (string, int, string, error, bool) { z := "bye" //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc) if y == z { return "", true, x, y, fmt.Errorf("same") @@ -134,7 +134,7 @@ func newFunction(y string, x int) (string, bool, int, string, error) { z = "hi" return "", true, x, z, nil } - return z, false, 0, "", nil + return z, 0, "", nil, false } -- return_complex_nonnested.go -- @@ -198,7 +198,7 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - shouldReturn, b := newFunction(n) + b, shouldReturn := newFunction(n) if shouldReturn { return b } //@loc(rflEnd, "}") @@ -263,7 +263,7 @@ package extract func _() string { x := 1 - shouldReturn, s := newFunction(x) + s, shouldReturn := newFunction(x) if shouldReturn { return s } //@loc(riEnd, "}") @@ -271,12 +271,12 @@ func _() string { return "b" } -func newFunction(x int) (bool, string) { +func newFunction(x int) (string, bool) { if x == 0 { //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri) x = 3 return true, "a" } - return false, "" + return "", false } -- return_init_nonnested.go -- diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt index 0b2622f1d58..71b664f8f7d 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt @@ -29,24 +29,24 @@ import ( ) func F() error { - a, b, shouldReturn, err := newFunction() - if shouldReturn { + a, b, err := newFunction() + if err != nil { return err } //@loc(endF, "}") fmt.Printf("%s %s", a, 
b) return nil } -func newFunction() ([]byte, []byte, bool, error) { +func newFunction() ([]byte, []byte, error) { a, err := json.Marshal(0) //@codeaction("a", "refactor.extract.function", end=endF, result=F) if err != nil { - return nil, nil, true, fmt.Errorf("1: %w", err) + return nil, nil, fmt.Errorf("1: %w", err) } b, err := json.Marshal(0) if err != nil { - return nil, nil, true, fmt.Errorf("2: %w", err) + return nil, nil, fmt.Errorf("2: %w", err) } - return a, b, false, nil + return a, b, nil } -- b.go -- @@ -77,7 +77,7 @@ import ( ) func G() (x, y int) { - v, shouldReturn, x1, y1 := newFunction() + v, x1, y1, shouldReturn := newFunction() if shouldReturn { return x1, y1 } //@loc(endG, "}") @@ -85,7 +85,7 @@ func G() (x, y int) { return 5, 6 } -func newFunction() (int, bool, int, int) { +func newFunction() (int, int, int, bool) { v := rand.Int() //@codeaction("v", "refactor.extract.function", end=endG, result=G) if v < 0 { return 0, true, 1, 2 @@ -93,5 +93,5 @@ func newFunction() (int, bool, int, int) { if v > 0 { return 0, true, 3, 4 } - return v, false, 0, 0 + return v, 0, 0, false } From 19c36ab2ce291f2930975f43afa2ad48b61c6a3e Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 14 May 2025 15:25:24 -0400 Subject: [PATCH 101/196] internal/tokeninternal: use go1.25's FileSet.AddExistingFiles Updates golang/go#73205 Change-Id: Ib21f5b20b690c28553f3e2e063635b0c3c4b44df Reviewed-on: https://go-review.googlesource.com/c/tools/+/672618 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/astutil/util_test.go | 1 - internal/tokeninternal/tokeninternal.go | 87 ----------------- internal/tokeninternal/tokeninternal_go124.go | 97 +++++++++++++++++++ internal/tokeninternal/tokeninternal_go125.go | 18 ++++ 4 files changed, 115 insertions(+), 88 deletions(-) create mode 100644 internal/tokeninternal/tokeninternal_go124.go create mode 100644 internal/tokeninternal/tokeninternal_go125.go diff --git a/internal/astutil/util_test.go 
b/internal/astutil/util_test.go index da07ea88594..bbb72d75ab2 100644 --- a/internal/astutil/util_test.go +++ b/internal/astutil/util_test.go @@ -63,5 +63,4 @@ func g() { if got := fmt.Sprint(gotStack); got != wantStack { t.Errorf("PreorderStack stack:\ngot: %s\nwant: %s", got, wantStack) } - } diff --git a/internal/tokeninternal/tokeninternal.go b/internal/tokeninternal/tokeninternal.go index 549bb183976..ccc09da57c6 100644 --- a/internal/tokeninternal/tokeninternal.go +++ b/internal/tokeninternal/tokeninternal.go @@ -7,97 +7,10 @@ package tokeninternal import ( - "fmt" "go/token" "slices" - "sort" - "sync" - "sync/atomic" - "unsafe" ) -// AddExistingFiles adds the specified files to the FileSet if they -// are not already present. It panics if any pair of files in the -// resulting FileSet would overlap. -// -// TODO(adonovan): add this a method to FileSet; see -// https://github.com/golang/go/issues/73205 -func AddExistingFiles(fset *token.FileSet, files []*token.File) { - - // This function cannot be implemented as: - // - // for _, file := range files { - // if prev := fset.File(token.Pos(file.Base())); prev != nil { - // if prev != file { - // panic("FileSet contains a different file at the same base") - // } - // continue - // } - // file2 := fset.AddFile(file.Name(), file.Base(), file.Size()) - // file2.SetLines(file.Lines()) - // } - // - // because all calls to AddFile must be in increasing order. - // AddExistingFiles lets us augment an existing FileSet - // sequentially, so long as all sets of files have disjoint - // ranges. - - // Punch through the FileSet encapsulation. - type tokenFileSet struct { - // This type remained essentially consistent from go1.16 to go1.21. - mutex sync.RWMutex - base int - files []*token.File - _ atomic.Pointer[token.File] - } - - // If the size of token.FileSet changes, this will fail to compile. 
- const delta = int64(unsafe.Sizeof(tokenFileSet{})) - int64(unsafe.Sizeof(token.FileSet{})) - var _ [-delta * delta]int - - type uP = unsafe.Pointer - var ptr *tokenFileSet - *(*uP)(uP(&ptr)) = uP(fset) - ptr.mutex.Lock() - defer ptr.mutex.Unlock() - - // Merge and sort. - newFiles := append(ptr.files, files...) - sort.Slice(newFiles, func(i, j int) bool { - return newFiles[i].Base() < newFiles[j].Base() - }) - - // Reject overlapping files. - // Discard adjacent identical files. - out := newFiles[:0] - for i, file := range newFiles { - if i > 0 { - prev := newFiles[i-1] - if file == prev { - continue - } - if prev.Base()+prev.Size()+1 > file.Base() { - panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)", - prev.Name(), prev.Base(), prev.Base()+prev.Size(), - file.Name(), file.Base(), file.Base()+file.Size())) - } - } - out = append(out, file) - } - newFiles = out - - ptr.files = newFiles - - // Advance FileSet.Base(). - if len(newFiles) > 0 { - last := newFiles[len(newFiles)-1] - newBase := last.Base() + last.Size() + 1 - if ptr.base < newBase { - ptr.base = newBase - } - } -} - // FileSetFor returns a new FileSet containing a sequence of new Files with // the same base, size, and line as the input files, for use in APIs that // require a FileSet. diff --git a/internal/tokeninternal/tokeninternal_go124.go b/internal/tokeninternal/tokeninternal_go124.go new file mode 100644 index 00000000000..6a002fcbb83 --- /dev/null +++ b/internal/tokeninternal/tokeninternal_go124.go @@ -0,0 +1,97 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.25 + +package tokeninternal + +import ( + "fmt" + "go/token" + "sort" + "sync" + "sync/atomic" + "unsafe" +) + +// AddExistingFiles adds the specified files to the FileSet if they +// are not already present. It panics if any pair of files in the +// resulting FileSet would overlap. 
+// +// TODO(adonovan): replace with FileSet.AddExistingFiles in go1.25. +func AddExistingFiles(fset *token.FileSet, files []*token.File) { + + // This function cannot be implemented as: + // + // for _, file := range files { + // if prev := fset.File(token.Pos(file.Base())); prev != nil { + // if prev != file { + // panic("FileSet contains a different file at the same base") + // } + // continue + // } + // file2 := fset.AddFile(file.Name(), file.Base(), file.Size()) + // file2.SetLines(file.Lines()) + // } + // + // because all calls to AddFile must be in increasing order. + // AddExistingFiles lets us augment an existing FileSet + // sequentially, so long as all sets of files have disjoint + // ranges. + + // Punch through the FileSet encapsulation. + type tokenFileSet struct { + // This type remained essentially consistent from go1.16 to go1.21. + mutex sync.RWMutex + base int + files []*token.File + _ atomic.Pointer[token.File] + } + + // If the size of token.FileSet changes, this will fail to compile. + const delta = int64(unsafe.Sizeof(tokenFileSet{})) - int64(unsafe.Sizeof(token.FileSet{})) + var _ [-delta * delta]int + + type uP = unsafe.Pointer + var ptr *tokenFileSet + *(*uP)(uP(&ptr)) = uP(fset) + ptr.mutex.Lock() + defer ptr.mutex.Unlock() + + // Merge and sort. + newFiles := append(ptr.files, files...) + sort.Slice(newFiles, func(i, j int) bool { + return newFiles[i].Base() < newFiles[j].Base() + }) + + // Reject overlapping files. + // Discard adjacent identical files. + out := newFiles[:0] + for i, file := range newFiles { + if i > 0 { + prev := newFiles[i-1] + if file == prev { + continue + } + if prev.Base()+prev.Size()+1 > file.Base() { + panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)", + prev.Name(), prev.Base(), prev.Base()+prev.Size(), + file.Name(), file.Base(), file.Base()+file.Size())) + } + } + out = append(out, file) + } + newFiles = out + + ptr.files = newFiles + + // Advance FileSet.Base(). 
+ if len(newFiles) > 0 { + last := newFiles[len(newFiles)-1] + newBase := last.Base() + last.Size() + 1 + if ptr.base < newBase { + ptr.base = newBase + } + } +} diff --git a/internal/tokeninternal/tokeninternal_go125.go b/internal/tokeninternal/tokeninternal_go125.go new file mode 100644 index 00000000000..9c9e9745935 --- /dev/null +++ b/internal/tokeninternal/tokeninternal_go125.go @@ -0,0 +1,18 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.25 + +package tokeninternal + +import "go/token" + +// AddExistingFiles adds the specified files to the FileSet if they +// are not already present. It panics if any pair of files in the +// resulting FileSet would overlap. +// +// TODO(adonovan): eliminate when go1.25 is always available. +func AddExistingFiles(fset *token.FileSet, files []*token.File) { + fset.AddExistingFiles(files...) +} From 87f67c8d21617631b8a1ecf8bcbc657c6cdf7164 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 19 May 2025 16:12:59 -0400 Subject: [PATCH 102/196] gopls/internal/debug: display Session.View.Folder.Options Change-Id: I71185aac510756055af54f7ec9be29da68fae4e7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674275 Reviewed-by: Robert Findley Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/debug/serve.go | 7 +++++++ gopls/internal/settings/settings.go | 25 +++++++++++++++++++++++++ gopls/internal/settings/staticcheck.go | 4 ++-- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/gopls/internal/debug/serve.go b/gopls/internal/debug/serve.go index 7cfe2b3d23e..b8fdfe0791f 100644 --- a/gopls/internal/debug/serve.go +++ b/gopls/internal/debug/serve.go @@ -800,6 +800,7 @@ var SessionTmpl = template.Must(template.Must(BaseTemplate.Clone()).Parse(` {{define "title"}}Session {{.ID}}{{end}} {{define "body"}} From: {{template "cachelink" .Cache.ID}}
+

Views

    {{range .Views}} {{- $envOverlay := .EnvOverlay -}} @@ -810,7 +811,13 @@ Root: {{.Root}}
    Env overlay: {{$envOverlay}})
    {{end -}} Folder: {{.Folder.Name}}:{{.Folder.Dir}} +Settings:
    +
      +{{range .Folder.Options.Debug}}
    • {{.}}
    • +{{end}} +
    {{end}}
+

Overlays

{{$session := .}}
    {{range .Overlays}} diff --git a/gopls/internal/settings/settings.go b/gopls/internal/settings/settings.go index 8a694854edd..b4a2c2cdcf2 100644 --- a/gopls/internal/settings/settings.go +++ b/gopls/internal/settings/settings.go @@ -8,6 +8,7 @@ import ( "fmt" "maps" "path/filepath" + "reflect" "strings" "time" @@ -56,6 +57,30 @@ type Options struct { InternalOptions } +// Debug returns a list of "name = value" strings for each Options field. +func (o *Options) Debug() []string { + var res []string + + var visitStruct func(v reflect.Value, path []string) + visitStruct = func(v reflect.Value, path []string) { + for i := range v.NumField() { + f := v.Field(i) + ftyp := v.Type().Field(i) + path := append(path, ftyp.Name) + if ftyp.Type.Kind() == reflect.Struct { + visitStruct(f, path) + } else { + res = append(res, fmt.Sprintf("%s = %#v", + strings.Join(path, "."), + f.Interface())) + } + } + } + visitStruct(reflect.ValueOf(o).Elem(), nil) + + return res +} + // ClientOptions holds LSP-specific configuration that is provided by the // client. 
// diff --git a/gopls/internal/settings/staticcheck.go b/gopls/internal/settings/staticcheck.go index 68e48819cfc..04b89d5629e 100644 --- a/gopls/internal/settings/staticcheck.go +++ b/gopls/internal/settings/staticcheck.go @@ -425,7 +425,7 @@ func initStaticcheckAnalyzers() (res []*Analyzer) { sa5005.SCAnalyzer: false, // requires buildir sa5007.SCAnalyzer: false, // requires buildir sa5008.SCAnalyzer: true, - sa5009.SCAnalyzer: nil, // requires buildir; redundant wrt 'printf' (#34494, + sa5009.SCAnalyzer: nil, // requires buildir; redundant wrt 'printf' (#34494) sa5010.SCAnalyzer: false, // requires buildir sa5011.SCAnalyzer: false, // requires buildir sa5012.SCAnalyzer: false, // requires buildir @@ -435,7 +435,7 @@ func initStaticcheckAnalyzers() (res []*Analyzer) { sa6003.SCAnalyzer: false, // requires buildir sa6005.SCAnalyzer: true, sa6006.SCAnalyzer: true, - sa9001.SCAnalyzer: false, // reports a "maybe" bug (low signal/noise, + sa9001.SCAnalyzer: false, // reports a "maybe" bug (low signal/noise) sa9002.SCAnalyzer: true, sa9003.SCAnalyzer: false, // requires buildir; NonDefault sa9004.SCAnalyzer: true, From babda1308b25201b5f4dc76f88ba7aaf8b38fbb8 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 19 May 2025 16:34:45 -0400 Subject: [PATCH 103/196] gopls: update to github.com/dominikh/go-tools@v0.6.1 Change-Id: I8fcbd1b3c29db50121b53f3b12c6d0ec07ead9db Reviewed-on: https://go-review.googlesource.com/c/tools/+/674256 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- gopls/go.mod | 2 +- gopls/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gopls/go.mod b/gopls/go.mod index 96c3fbb127a..9868579f20d 100644 --- a/gopls/go.mod +++ b/gopls/go.mod @@ -14,7 +14,7 @@ require ( golang.org/x/tools v0.30.0 golang.org/x/vuln v1.1.4 gopkg.in/yaml.v3 v3.0.1 - honnef.co/go/tools v0.6.0 + honnef.co/go/tools v0.6.1 mvdan.cc/gofumpt v0.7.0 mvdan.cc/xurls/v2 v2.6.0 ) diff --git a/gopls/go.sum b/gopls/go.sum 
index 27f999d51a4..143edbc8909 100644 --- a/gopls/go.sum +++ b/gopls/go.sum @@ -59,8 +59,8 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.6.0 h1:TAODvD3knlq75WCp2nyGJtT4LeRV/o7NN9nYPeVJXf8= -honnef.co/go/tools v0.6.0/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= +honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= +honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= mvdan.cc/gofumpt v0.7.0/go.mod h1:txVFJy/Sc/mvaycET54pV8SW8gWxTlUuGHVEcncmNUo= mvdan.cc/xurls/v2 v2.6.0 h1:3NTZpeTxYVWNSokW3MKeyVkz/j7uYXYiMtXRUfmjbgI= From b62c6c16625b2809d92238b950757447e7a28a56 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 19 May 2025 20:52:51 +0000 Subject: [PATCH 104/196] internal/mcp: remove misplaced TODO comment Remove a TODO that was in package doc position. Change-Id: Ib98dbcd7d733fdf0059bc3ced1fd6a1008b91a26 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674257 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 23cd53726a8..c9fed8b134f 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-// TODO: consider passing Transport to NewClient and merging {Connection,Client}Options package mcp import ( From a2c2a72556492d3a7c1430afe2bb4fa87d586cb5 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 20 May 2025 13:39:42 +0000 Subject: [PATCH 105/196] internal/tokeninternal: avoid use of AddExistingFiles, for now Recently, token.FileSet.AddExistingFiles was added to the standard library. However, our immediate use of it in x/tools is causing problems due to version skew when we import into Google, where the toolchain may be a week or two old. Temporarily avoid the use of this new API. Change-Id: I2cc6d89b32505d07c3b7275b4279f952e0ecf28c Reviewed-on: https://go-review.googlesource.com/c/tools/+/674435 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI Auto-Submit: Robert Findley --- internal/tokeninternal/tokeninternal_go124.go | 5 ++++- internal/tokeninternal/tokeninternal_go125.go | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/tokeninternal/tokeninternal_go124.go b/internal/tokeninternal/tokeninternal_go124.go index 6a002fcbb83..da34ae608ca 100644 --- a/internal/tokeninternal/tokeninternal_go124.go +++ b/internal/tokeninternal/tokeninternal_go124.go @@ -2,7 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !go1.25 +// TODO(rfindley): once the new AddExistingFiles API has had some time to soak +// in std, use it in x/tools and change the condition below to !go1.25. + +//go:build !addexistingfiles package tokeninternal diff --git a/internal/tokeninternal/tokeninternal_go125.go b/internal/tokeninternal/tokeninternal_go125.go index 9c9e9745935..712c3414130 100644 --- a/internal/tokeninternal/tokeninternal_go125.go +++ b/internal/tokeninternal/tokeninternal_go125.go @@ -2,7 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-//go:build go1.25 +// TODO(rfindley): once the new AddExistingFiles API has had some time to soak +// in std, use it here behind the go1.25 build tag. + +//go:build addexistingfiles package tokeninternal From 10eb2f3731a56476495d194ce2127245f6203af9 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Fri, 16 May 2025 13:13:30 -0400 Subject: [PATCH 106/196] gopls/internal/mcp: add location info to mcp tool input arg Make a copy of the lsp protocol.Location type and all its dependency types in the MCP server. The MCP server uses JSON tags (converted to a JSON schema) to inform the MCP client about the parameters. In marker tests, the value marker @loc can collect location information and pass it to the MCP tool's input parameter. For now, the marker test hard-codes the field name as "loc". The hello_world mcp tool will return both the request name and the location's file name. For golang/go#73580 Change-Id: Iaccdcde845ff76d21a8a04233d96be9f2622f775 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674255 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/mcp/mcp.go | 44 +++++++++++++++---- gopls/internal/test/marker/doc.go | 13 +++--- gopls/internal/test/marker/marker_test.go | 8 +++- .../marker/testdata/mcptools/hello_world.txt | 6 ++- 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index 1a4a595cd54..e8744f76a56 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -10,6 +10,7 @@ import ( "log" "net" "net/http" + "path" "sync" "golang.org/x/tools/gopls/internal/cache" @@ -114,18 +115,45 @@ func newServer(_ *cache.Cache, session *cache.Session) *mcp.Server { s := mcp.NewServer("golang", "v0.1", nil) // TODO(hxjiang): replace dummy tool with tools which use cache and session.
- s.AddTools(mcp.NewTool("hello_world", "Say hello to someone", helloHandler(session))) + s.AddTools( + mcp.NewTool( + "hello_world", + "Say hello to someone", + func(ctx context.Context, _ *mcp.ServerSession, request HelloParams) ([]*mcp.Content, error) { + return helloHandler(ctx, session, request) + }, + ), + ) return s } type HelloParams struct { - Name string `json:"name" mcp:"the name to say hi to"` + Name string `json:"name" mcp:"the name to say hi to"` + Location Location `json:"loc" mcp:"location inside of a text file"` } -func helloHandler(_ *cache.Session) func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { - return func(ctx context.Context, cc *mcp.ServerSession, request *HelloParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + request.Name), - }, nil - } +func helloHandler(_ context.Context, _ *cache.Session, request HelloParams) ([]*mcp.Content, error) { + return []*mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("Hi %s, current file %s.", request.Name, path.Base(request.Location.URI))), + }, nil +} + +// Location describes a range within a text document. +// +// It is structurally equal to protocol.Location, but has mcp tags instead of json. +// TODO(hxjiang): experiment if the LLM can correctly provide the right location +// information. 
+type Location struct { + URI string `json:"uri" mcp:"URI to the text file"` + Range Range `json:"range" mcp:"range within text document"` +} + +type Range struct { + Start Position `json:"start" mcp:"the range's start position"` + End Position `json:"end" mcp:"the range's end position"` +} + +type Position struct { + Line uint32 `json:"line" mcp:"line number (zero-based)"` + Character uint32 `json:"character" mcp:"column number (zero-based, UTF-16 encoding)"` } diff --git a/gopls/internal/test/marker/doc.go b/gopls/internal/test/marker/doc.go index 131d799a758..f3ad975d6fd 100644 --- a/gopls/internal/test/marker/doc.go +++ b/gopls/internal/test/marker/doc.go @@ -307,12 +307,13 @@ Here is the list of supported action markers: location name kind - - mcptool(name string, args string, output=golden): Executes an MCP tool - call using the provided tool name and args (a JSON-encoded value). It then - asserts that the MCP server's response matches the content of the golden - file identified by output. Unlike golden references for file edits or file - results, which may contain multiple files (each with a path), the output - golden content here is a single entity, effectively having an empty path(""). + - mcptool(name string, arg string, src location, output=golden): Executes an + MCP tool call using the provided tool name and args (a JSON-encoded value) + with the source location. It then asserts that the MCP server's response + matches the content of the golden file identified by output. Unlike golden + references for file edits or file results, which may contain multiple files + (each with a path), the output golden content here is a single entity, + effectively having an empty path(""). 
# Argument conversion diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index f8aa59634d7..4288b756da9 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -2428,13 +2428,18 @@ func itemLocation(item protocol.CallHierarchyItem) protocol.Location { } } -func mcpToolMarker(mark marker, tool string, args string) { +func mcpToolMarker(mark marker, tool string, args string, loc protocol.Location) { var toolArgs map[string]any if err := json.Unmarshal([]byte(args), &toolArgs); err != nil { mark.errorf("fail to unmarshal arguments to map[string]any: %v", err) return } + // Inserts the location value into the MCP tool arguments map under the + // "loc" key. + // TODO(hxjiang): Make the "loc" key configurable. + toolArgs["loc"] = loc + res, err := mark.run.env.MCPSession.CallTool(mark.ctx(), tool, toolArgs, nil) if err != nil { mark.errorf("failed to call mcp tool: %v", err) @@ -2445,6 +2450,7 @@ func mcpToolMarker(mark marker, tool string, args string) { for i, c := range res.Content { if c.Type != "text" { mark.errorf("unsupported return content[%v] type: %s", i, c.Type) + continue } buf.WriteString(c.Text) } diff --git a/gopls/internal/test/marker/testdata/mcptools/hello_world.txt b/gopls/internal/test/marker/testdata/mcptools/hello_world.txt index 5bae5afa416..8ae6f745565 100644 --- a/gopls/internal/test/marker/testdata/mcptools/hello_world.txt +++ b/gopls/internal/test/marker/testdata/mcptools/hello_world.txt @@ -9,7 +9,9 @@ module golang.org/mcptests/mcptools -- mcp/tools/helloworld.go -- package helloworld -func A() {} //@mcptool("hello_world", `{"name": "jerry"}`, output=hello) +func A() {} //@loc(loc, "A") + +//@mcptool("hello_world", `{"name": "jerry"}`, loc, output=hello) -- @hello -- -Hi jerry +Hi jerry, current file helloworld.go. 
From 150502a9817395ac1ee4e3bcc6bde76c7e940d20 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Mon, 19 May 2025 14:57:58 +0000 Subject: [PATCH 107/196] internal/mcp: keep an ordered list of feature keys This change is necessary to support pagination which will be implemented in subsequent CLs. Change-Id: I8bc616a57463ea569dc7cadf1d4f06c69602ddd5 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674135 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/examples/sse/main.go | 2 +- internal/mcp/features.go | 52 +++++++++++++++++---- internal/mcp/features_test.go | 75 +++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 9 deletions(-) create mode 100644 internal/mcp/features_test.go diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index b5a1cec1aac..201f3a091ce 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -16,7 +16,7 @@ import ( var httpAddr = flag.String("http", "", "use SSE HTTP at this address") type SayHiParams struct { - Name string `json:"name" mcp:"the name to say hi to"` + Name string `json:"name"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]*mcp.Content, error) { diff --git a/internal/mcp/features.go b/internal/mcp/features.go index 42e74c86aaf..af8b5416b67 100644 --- a/internal/mcp/features.go +++ b/internal/mcp/features.go @@ -6,8 +6,8 @@ package mcp import ( "iter" - - "golang.org/x/tools/internal/mcp/internal/util" + "maps" + "slices" ) // This file contains implementations that are common to all features. @@ -17,9 +17,11 @@ import ( // A featureSet is a collection of features of type T. // Every feature has a unique ID, and the spec never mentions // an ordering for the List calls, so what it calls a "list" is actually a set. 
+// TODO: switch to an ordered map type featureSet[T any] struct { - uniqueID func(T) string - features map[string]T + uniqueID func(T) string + features map[string]T + sortedKeys []string // lazily computed; nil after add or remove } // newFeatureSet creates a new featureSet for features of type T. @@ -37,6 +39,7 @@ func (s *featureSet[T]) add(fs ...T) { for _, f := range fs { s.features[s.uniqueID(f)] = f } + s.sortedKeys = nil } // remove removes all features with the given uids from the set if present, @@ -50,6 +53,9 @@ func (s *featureSet[T]) remove(uids ...string) bool { delete(s.features, uid) } } + if changed { + s.sortedKeys = nil + } return changed } @@ -63,11 +69,41 @@ func (s *featureSet[T]) get(uid string) (T, bool) { // all returns an iterator over of all the features in the set // sorted by unique ID. func (s *featureSet[T]) all() iter.Seq[T] { + s.sortKeys() + return func(yield func(T) bool) { + s.yieldFrom(0, yield) + } +} + +// above returns an iterator over features in the set whose unique IDs are +// greater than `uid`, in ascending ID order. +func (s *featureSet[T]) above(uid string) iter.Seq[T] { + s.sortKeys() + index, found := slices.BinarySearch(s.sortedKeys, uid) + if found { + index++ + } return func(yield func(T) bool) { - for _, f := range util.Sorted(s.features) { - if !yield(f) { - return - } + s.yieldFrom(index, yield) + } +} + +// sortKeys is a helper that maintains a sorted list of feature IDs. It +// computes this list lazily upon its first call after a modification, or +// if it's nil. +func (s *featureSet[T]) sortKeys() { + if s.sortedKeys != nil { + return + } + s.sortedKeys = slices.Sorted(maps.Keys(s.features)) +} + +// yieldFrom is a helper that iterates over the features in the set, +// starting at the given index, and calls the yield function for each one. 
+func (s *featureSet[T]) yieldFrom(index int, yield func(T) bool) { + for i := index; i < len(s.sortedKeys); i++ { + if !yield(s.features[s.sortedKeys[i]]) { + return } } } diff --git a/internal/mcp/features_test.go b/internal/mcp/features_test.go new file mode 100644 index 00000000000..2bda8745932 --- /dev/null +++ b/internal/mcp/features_test.go @@ -0,0 +1,75 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/tools/internal/mcp/jsonschema" +) + +type SayHiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, cc *ServerSession, params *SayHiParams) ([]*Content, error) { + return []*Content{ + NewTextContent("Hi " + params.Name), + }, nil +} + +func TestFeatureSetOrder(t *testing.T) { + toolA := NewTool("apple", "apple tool", SayHi).Tool + toolB := NewTool("banana", "banana tool", SayHi).Tool + toolC := NewTool("cherry", "cherry tool", SayHi).Tool + + testCases := []struct { + tools []*Tool + want []*Tool + }{ + {[]*Tool{toolA, toolB, toolC}, []*Tool{toolA, toolB, toolC}}, + {[]*Tool{toolB, toolC, toolA}, []*Tool{toolA, toolB, toolC}}, + {[]*Tool{toolA, toolC}, []*Tool{toolA, toolC}}, + {[]*Tool{toolA, toolA, toolA}, []*Tool{toolA}}, + {[]*Tool{}, nil}, + } + for _, tc := range testCases { + fs := newFeatureSet(func(t *Tool) string { return t.Name }) + fs.add(tc.tools...) 
+ got := slices.Collect(fs.all()) + if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff) + } + } +} + +func TestFeatureSetAbove(t *testing.T) { + toolA := NewTool("apple", "apple tool", SayHi).Tool + toolB := NewTool("banana", "banana tool", SayHi).Tool + toolC := NewTool("cherry", "cherry tool", SayHi).Tool + + testCases := []struct { + tools []*Tool + above string + want []*Tool + }{ + {[]*Tool{toolA, toolB, toolC}, "apple", []*Tool{toolB, toolC}}, + {[]*Tool{toolA, toolB, toolC}, "banana", []*Tool{toolC}}, + {[]*Tool{toolA, toolB, toolC}, "cherry", nil}, + } + for _, tc := range testCases { + fs := newFeatureSet(func(t *Tool) string { return t.Name }) + fs.add(tc.tools...) + got := slices.Collect(fs.above(tc.above)) + if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff) + + } + } +} From 9460f2fc87e695183ba5d72b3fb129874889be3c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 20 May 2025 13:33:02 -0400 Subject: [PATCH 108/196] gopls/internal/analysis/recursiveiter: report inefficient recursion This CL adds a new gopls analyzer, recursiveiter, that reports inefficient recursive uses of range-over-func in the implementation of an iterator (iter.Seq or iter.Seq2). The issue is rather subtle and the problem hard to see, but it is quite easy to detect (though not fix) analytically. Here's an example bug from the module mirror corpus: https://go-mod-viewer.appspot.com/golang.org/x/arch@v0.17.0/internal/unify/yaml_test.go#L84 (I don't expect there will be many yet since iterators are new, but the check is cheap.) Also: report an error if a marker test file has duplicate sections. Thanks to Rob for pointing out my error. 
+ test, doc, relnote Change-Id: If9e625bf321aa5fa3a2b1f2ab0a65f31d445be70 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674438 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/doc/analyzers.md | 89 ++++++++++++++++++ gopls/doc/release/v0.19.0.md | 11 +++ gopls/internal/analysis/recursiveiter/doc.go | 92 ++++++++++++++++++ gopls/internal/analysis/recursiveiter/main.go | 16 ++++ .../analysis/recursiveiter/recursiveiter.go | 94 +++++++++++++++++++ .../recursiveiter/recursiveiter_test.go | 17 ++++ .../recursiveiter/testdata/src/a/a.go | 30 ++++++ gopls/internal/doc/api.json | 12 +++ gopls/internal/settings/analysis.go | 6 +- gopls/internal/test/marker/marker_test.go | 5 + .../marker/testdata/diagnostics/analyzers.txt | 20 ++-- 11 files changed, 384 insertions(+), 8 deletions(-) create mode 100644 gopls/internal/analysis/recursiveiter/doc.go create mode 100644 gopls/internal/analysis/recursiveiter/main.go create mode 100644 gopls/internal/analysis/recursiveiter/recursiveiter.go create mode 100644 gopls/internal/analysis/recursiveiter/recursiveiter_test.go create mode 100644 gopls/internal/analysis/recursiveiter/testdata/src/a/a.go diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index cea19c40ca3..2a974eaa496 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3894,6 +3894,95 @@ Default: on. Package documentation: [printf](https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/printf) + +## `recursiveiter`: check for inefficient recursive iterators + + +This analyzer reports when a function that returns an iterator +(iter.Seq or iter.Seq2) calls itself as the operand of a range +statement, as this is inefficient. + +When implementing an iterator (e.g. iter.Seq[T]) for a recursive +data type such as a tree or linked list, it is tempting to +recursively range over the iterator for each child element. 
+ +Here's an example of a naive iterator over a binary tree: + + type tree struct { + value int + left, right *tree + } + + func (t *tree) All() iter.Seq[int] { + return func(yield func(int) bool) { + if t != nil { + for elem := range t.left.All() { // "inefficient recursive iterator" + if !yield(elem) { + return + } + } + if !yield(t.value) { + return + } + for elem := range t.right.All() { // "inefficient recursive iterator" + if !yield(elem) { + return + } + } + } + } + } + +Though it correctly enumerates the elements of the tree, it hides a +significant performance problem--two, in fact. Consider a balanced +tree of N nodes. Iterating the root node will cause All to be +called once on every node of the tree. This results in a chain of +nested active range-over-func statements when yield(t.value) is +called on a leaf node. + +The first performance problem is that each range-over-func +statement must typically heap-allocate a variable, so iteration of +the tree allocates as many variables as there are elements in the +tree, for a total of O(N) allocations, all unnecessary. + +The second problem is that each call to yield for a leaf of the +tree causes each of the enclosing range loops to receive a value, +which they then immediately pass on to their respective yield +function. This results in a chain of log(N) dynamic yield calls per +element, a total of O(N*log N) dynamic calls overall, when only +O(N) are necessary. + +A better implementation strategy for recursive iterators is to +first define the "every" operator for your recursive data type, +where every(f) reports whether f(x) is true for every element x in +the data type. 
For our tree, the every function would be: + + func (t *tree) every(f func(int) bool) bool { + return t == nil || + t.left.every(f) && f(t.value) && t.right.every(f) + } + +Then the iterator can be simply expressed as a trivial wrapper +around this function: + + func (t *tree) All() iter.Seq[int] { + return func(yield func(int) bool) { + _ = t.every(yield) + } + } + +In effect, tree.All computes whether yield returns true for each +element, short-circuiting if it every returns false, then discards +the final boolean result. + +This has much better performance characteristics: it makes one +dynamic call per element of the tree, and it doesn't heap-allocate +anything. It is also clearer. + +Default: on. + +Package documentation: [recursiveiter](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/recursiveiter) + ## `shadow`: check for possible unintended shadowing of variables diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index 8842098639e..9ae1feb6c36 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -47,6 +47,17 @@ setting now takes precedence over the `staticcheck` setting, so, regardless of what value of `staticcheck` you use (true/false/unset), you can make adjustments to your preferred set of analyzers. +## "Inefficient recursive iterator" analyzer + +A common pitfall when writing a function that returns an iterator +(iter.Seq) for a recursive data type is to recursively call the +function from its own implementation, leading to a stack of nested +coroutines, which is inefficient. + +The new `recursiveiter` analyzer detects such mistakes; see +[its documentation](https://golang.org/x/tools/gopls/internal/analysis/recursiveiter) +for details, including tips on how to define simple and +efficient recursive iterators.
## "Implementations" supports signature types diff --git a/gopls/internal/analysis/recursiveiter/doc.go b/gopls/internal/analysis/recursiveiter/doc.go new file mode 100644 index 00000000000..eb9c6c92bb0 --- /dev/null +++ b/gopls/internal/analysis/recursiveiter/doc.go @@ -0,0 +1,92 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package recursiveiter defines an Analyzer that checks for mistakes +// in iterators for recursive data structures. +// +// # Analyzer recursiveiter +// +// recursiveiter: check for inefficient recursive iterators +// +// This analyzer reports when a function that returns an iterator +// (iter.Seq or iter.Seq2) calls itself as the operand of a range +// statement, as this is inefficient. +// +// When implementing an iterator (e.g. iter.Seq[T]) for a recursive +// data type such as a tree or linked list, it is tempting to +// recursively range over the iterator for each child element. +// +// Here's an example of a naive iterator over a binary tree: +// +// type tree struct { +// value int +// left, right *tree +// } +// +// func (t *tree) All() iter.Seq[int] { +// return func(yield func(int) bool) { +// if t != nil { +// for elem := range t.left.All() { // "inefficient recursive iterator" +// if !yield(elem) { +// return +// } +// } +// if !yield(t.value) { +// return +// } +// for elem := range t.right.All() { // "inefficient recursive iterator" +// if !yield(elem) { +// return +// } +// } +// } +// } +// } +// +// Though it correctly enumerates the elements of the tree, it hides a +// significant performance problem--two, in fact. Consider a balanced +// tree of N nodes. Iterating the root node will cause All to be +// called once on every node of the tree. This results in a chain of +// nested active range-over-func statements when yield(t.value) is +// called on a leaf node. 
+// +// The first performance problem is that each range-over-func +// statement must typically heap-allocate a variable, so iteration of +// the tree allocates as many variables as there are elements in the +// tree, for a total of O(N) allocations, all unnecessary. +// +// The second problem is that each call to yield for a leaf of the +// tree causes each of the enclosing range loops to receive a value, +// which they then immediately pass on to their respective yield +// function. This results in a chain of log(N) dynamic yield calls per +// element, a total of O(N*log N) dynamic calls overall, when only +// O(N) are necessary. +// +// A better implementation strategy for recursive iterators is to +// first define the "every" operator for your recursive data type, +// where every(f) reports whether f(x) is true for every element x in +// the data type. For our tree, the every function would be: +// +// func (t *tree) every(f func(int) bool) bool { +// return t == nil || +// t.left.every(f) && f(t.value) && t.right.every(f) +// } +// +// Then the iterator can be simply expressed as a trivial wrapper +// around this function: +// +// func (t *tree) All() iter.Seq[int] { +// return func(yield func(int) bool) { +// _ = t.every(yield) +// } +// } +// +// In effect, tree.All computes whether yield returns true for each +// element, short-circuiting if it every returns false, then discards +// the final boolean result. +// +// This has much better performance characteristics: it makes one +// dynamic call per element of the tree, and it doesn't heap-allocate +// anything. It is also clearer. +package recursiveiter diff --git a/gopls/internal/analysis/recursiveiter/main.go b/gopls/internal/analysis/recursiveiter/main.go new file mode 100644 index 00000000000..5f4b9720681 --- /dev/null +++ b/gopls/internal/analysis/recursiveiter/main.go @@ -0,0 +1,16 @@ +// Copyright 2025 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +// The recursiveiter command applies the recursiveiter analyzer to the +// specified packages of Go source code. +package main + +import ( + "golang.org/x/tools/go/analysis/singlechecker" + "golang.org/x/tools/gopls/internal/analysis/recursiveiter" +) + +func main() { singlechecker.Main(recursiveiter.Analyzer) } diff --git a/gopls/internal/analysis/recursiveiter/recursiveiter.go b/gopls/internal/analysis/recursiveiter/recursiveiter.go new file mode 100644 index 00000000000..0064a37386f --- /dev/null +++ b/gopls/internal/analysis/recursiveiter/recursiveiter.go @@ -0,0 +1,94 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package recursiveiter + +import ( + _ "embed" + "go/ast" + "go/types" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/go/types/typeutil" + "golang.org/x/tools/internal/analysisinternal" + typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex" + "golang.org/x/tools/internal/typesinternal/typeindex" +) + +//go:embed doc.go +var doc string + +var Analyzer = &analysis.Analyzer{ + Name: "recursiveiter", + Doc: analysisinternal.MustExtractDoc(doc, "recursiveiter"), + Requires: []*analysis.Analyzer{inspect.Analyzer, typeindexanalyzer.Analyzer}, + Run: run, + URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/recursiveiter", +} + +func run(pass *analysis.Pass) (any, error) { + var ( + inspector = pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index) + info = pass.TypesInfo + iterSeq = index.Object("iter", "Seq") + iterSeq2 = index.Object("iter", "Seq2") + ) + if iterSeq == nil || iterSeq2 == nil { + return nil,
nil // fast path: no iterators + } + + // Search for a function or method f that returns an iter.Seq + // or Seq2 and calls itself recursively within a range stmt: + // + // func f(...) iter.Seq[E] { + // return func(yield func(E) bool) { + // ... + // for range f(...) { ... } + // } + // } + for curDecl := range inspector.Root().Preorder((*ast.FuncDecl)(nil)) { + decl := curDecl.Node().(*ast.FuncDecl) + fn := info.Defs[decl.Name].(*types.Func) + results := fn.Signature().Results() + if results.Len() != 1 { + continue // result not a singleton + } + retType, ok := results.At(0).Type().(*types.Named) + if !ok { + continue // result not a named type + } + switch retType.Origin().Obj() { + case iterSeq, iterSeq2: + default: + continue // result not iter.Seq{,2} + } + // Have: a FuncDecl that returns an iterator. + for curRet := range curDecl.Preorder((*ast.ReturnStmt)(nil)) { + ret := curRet.Node().(*ast.ReturnStmt) + if len(ret.Results) != 1 || !is[*ast.FuncLit](ret.Results[0]) { + continue // not "return func(){...}" + } + for curRange := range curRet.Preorder((*ast.RangeStmt)(nil)) { + rng := curRange.Node().(*ast.RangeStmt) + call, ok := rng.X.(*ast.CallExpr) + if !ok { + continue + } + if typeutil.StaticCallee(info, call) == fn { + pass.Reportf(rng.Range, "inefficient recursion in iterator %s", fn.Name()) + } + } + } + } + + return nil, nil +} + +func is[T any](x any) bool { + _, ok := x.(T) + return ok +} diff --git a/gopls/internal/analysis/recursiveiter/recursiveiter_test.go b/gopls/internal/analysis/recursiveiter/recursiveiter_test.go new file mode 100644 index 00000000000..9dcf6c8b996 --- /dev/null +++ b/gopls/internal/analysis/recursiveiter/recursiveiter_test.go @@ -0,0 +1,17 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package recursiveiter_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/gopls/internal/analysis/recursiveiter" +) + +func Test(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, recursiveiter.Analyzer, "a") +} diff --git a/gopls/internal/analysis/recursiveiter/testdata/src/a/a.go b/gopls/internal/analysis/recursiveiter/testdata/src/a/a.go new file mode 100644 index 00000000000..091e17513fb --- /dev/null +++ b/gopls/internal/analysis/recursiveiter/testdata/src/a/a.go @@ -0,0 +1,30 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package recursiveiter + +import "iter" + +type cons struct { + car int + cdr *cons +} + +func (cons *cons) All() iter.Seq[int] { + return func(yield func(int) bool) { + // The correct recursion is: + // func (cons *cons) all(f func(int) bool) bool { + // return cons == nil || f(cons.car) && cons.cdr.all(f) + // } + // then: + // _ = cons.all(yield) + if cons != nil && yield(cons.car) { + for elem := range cons.All() { // want "inefficient recursion in iterator All" + if !yield(elem) { + break + } + } + } + } +} diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 24fdc0e3835..fa73b711868 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1526,6 +1526,12 @@ "Default": "true", "Status": "" }, + { + "Name": "\"recursiveiter\"", + "Doc": "check for inefficient recursive iterators\n\nThis analyzer reports when a function that returns an iterator\n(iter.Seq or iter.Seq2) calls itself as the operand of a range\nstatement, as this is inefficient.\n\nWhen implementing an iterator (e.g.
iter.Seq[T]) for a recursive\ndata type such as a tree or linked list, it is tempting to\nrecursively range over the iterator for each child element.\n\nHere's an example of a naive iterator over a binary tree:\n\n\ttype tree struct {\n\t\tvalue int\n\t\tleft, right *tree\n\t}\n\n\tfunc (t *tree) All() iter.Seq[int] {\n\t\treturn func(yield func(int) bool) {\n\t\t\tif t != nil {\n\t\t\t\tfor elem := range t.left.All() { // \"inefficient recursive iterator\"\n\t\t\t\t\tif !yield(elem) {\n\t\t\t\t\t\treturn\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\tif !yield(t.value) {\n\t\t\t\t\treturn\n\t\t\t\t}\n\t\t\t\tfor elem := range t.right.All() { // \"inefficient recursive iterator\"\n\t\t\t\t\tif !yield(elem) {\n\t\t\t\t\t\treturn\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t}\n\nThough it correctly enumerates the elements of the tree, it hides a\nsignificant performance problem--two, in fact. Consider a balanced\ntree of N nodes. Iterating the root node will cause All to be\ncalled once on every node of the tree. This results in a chain of\nnested active range-over-func statements when yield(t.value) is\ncalled on a leaf node.\n\nThe first performance problem is that each range-over-func\nstatement must typically heap-allocate a variable, so iteration of\nthe tree allocates as many variables as there are elements in the\ntree, for a total of O(N) allocations, all unnecessary.\n\nThe second problem is that each call to yield for a leaf of the\ntree causes each of the enclosing range loops to receive a value,\nwhich they then immediately pass on to their respective yield\nfunction. This results in a chain of log(N) dynamic yield calls per\nelement, a total of O(N*log N) dynamic calls overall, when only\nO(N) are necessary.\n\nA better implementation strategy for recursive iterators is to\nfirst define the \"every\" operator for your recursive data type,\nwhere every(f) reports whether f(x) is true for every element x in\nthe data type. 
For our tree, the every function would be:\n\n\tfunc (t *tree) every(f func(int) bool) bool {\n\t\treturn t == nil ||\n\t\t\tt.left.every(f) \u0026\u0026 f(t.value) \u0026\u0026 t.right.every(f)\n\t}\n\nThen the iterator can be simply expressed as a trivial wrapper\naround this function:\n\n\tfunc (t *tree) All() iter.Seq[int] {\n\t\treturn func(yield func(int) bool) {\n\t\t\t_ = t.every(yield)\n\t\t}\n\t}\n\nIn effect, tree.All computes whether yield returns true for each\nelement, short-circuiting if it every returns false, then discards\nthe final boolean result.\n\nThis has much better performance characteristics: it makes one\ndynamic call per element of the tree, and it doesn't heap-allocate\nanything. It is also clearer.", + "Default": "true", + "Status": "" + }, { "Name": "\"shadow\"", "Doc": "check for possible unintended shadowing of variables\n\nThis analyzer check for shadowed variables.\nA shadowed variable is a variable declared in an inner scope\nwith the same name and type as a variable in an outer scope,\nand where the outer variable is mentioned after the inner one\nis declared.\n\n(This definition can be refined; the module generates too many\nfalse positives and is not yet enabled by default.)\n\nFor example:\n\n\tfunc BadRead(f *os.File, buf []byte) error {\n\t\tvar err error\n\t\tfor {\n\t\t\tn, err := f.Read(buf) // shadows the function variable 'err'\n\t\t\tif err != nil {\n\t\t\t\tbreak // causes return of wrong value\n\t\t\t}\n\t\t\tfoo(buf)\n\t\t}\n\t\treturn err\n\t}", @@ -3246,6 +3252,12 @@ "URL": "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/printf", "Default": true }, + { + "Name": "recursiveiter", + "Doc": "check for inefficient recursive iterators\n\nThis analyzer reports when a function that returns an iterator\n(iter.Seq or iter.Seq2) calls itself as the operand of a range\nstatement, as this is inefficient.\n\nWhen implementing an iterator (e.g. 
iter.Seq[T]) for a recursive\ndata type such as a tree or linked list, it is tempting to\nrecursively range over the iterator for each child element.\n\nHere's an example of a naive iterator over a binary tree:\n\n\ttype tree struct {\n\t\tvalue int\n\t\tleft, right *tree\n\t}\n\n\tfunc (t *tree) All() iter.Seq[int] {\n\t\treturn func(yield func(int) bool) {\n\t\t\tif t != nil {\n\t\t\t\tfor elem := range t.left.All() { // \"inefficient recursive iterator\"\n\t\t\t\t\tif !yield(elem) {\n\t\t\t\t\t\treturn\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\tif !yield(t.value) {\n\t\t\t\t\treturn\n\t\t\t\t}\n\t\t\t\tfor elem := range t.right.All() { // \"inefficient recursive iterator\"\n\t\t\t\t\tif !yield(elem) {\n\t\t\t\t\t\treturn\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t}\n\nThough it correctly enumerates the elements of the tree, it hides a\nsignificant performance problem--two, in fact. Consider a balanced\ntree of N nodes. Iterating the root node will cause All to be\ncalled once on every node of the tree. This results in a chain of\nnested active range-over-func statements when yield(t.value) is\ncalled on a leaf node.\n\nThe first performance problem is that each range-over-func\nstatement must typically heap-allocate a variable, so iteration of\nthe tree allocates as many variables as there are elements in the\ntree, for a total of O(N) allocations, all unnecessary.\n\nThe second problem is that each call to yield for a leaf of the\ntree causes each of the enclosing range loops to receive a value,\nwhich they then immediately pass on to their respective yield\nfunction. This results in a chain of log(N) dynamic yield calls per\nelement, a total of O(N*log N) dynamic calls overall, when only\nO(N) are necessary.\n\nA better implementation strategy for recursive iterators is to\nfirst define the \"every\" operator for your recursive data type,\nwhere every(f) reports whether f(x) is true for every element x in\nthe data type. 
For our tree, the every function would be:\n\n\tfunc (t *tree) every(f func(int) bool) bool {\n\t\treturn t == nil ||\n\t\t\tt.left.every(f) \u0026\u0026 f(t.value) \u0026\u0026 t.right.every(f)\n\t}\n\nThen the iterator can be simply expressed as a trivial wrapper\naround this function:\n\n\tfunc (t *tree) All() iter.Seq[int] {\n\t\treturn func(yield func(int) bool) {\n\t\t\t_ = t.every(yield)\n\t\t}\n\t}\n\nIn effect, tree.All computes whether yield returns true for each\nelement, short-circuiting if it every returns false, then discards\nthe final boolean result.\n\nThis has much better performance characteristics: it makes one\ndynamic call per element of the tree, and it doesn't heap-allocate\nanything. It is also clearer.", + "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/recursiveiter", + "Default": true + }, { "Name": "shadow", "Doc": "check for possible unintended shadowing of variables\n\nThis analyzer check for shadowed variables.\nA shadowed variable is a variable declared in an inner scope\nwith the same name and type as a variable in an outer scope,\nand where the outer variable is mentioned after the inner one\nis declared.\n\n(This definition can be refined; the module generates too many\nfalse positives and is not yet enabled by default.)\n\nFor example:\n\n\tfunc BadRead(f *os.File, buf []byte) error {\n\t\tvar err error\n\t\tfor {\n\t\t\tn, err := f.Read(buf) // shadows the function variable 'err'\n\t\t\tif err != nil {\n\t\t\t\tbreak // causes return of wrong value\n\t\t\t}\n\t\t\tfoo(buf)\n\t\t}\n\t\treturn err\n\t}", diff --git a/gopls/internal/settings/analysis.go b/gopls/internal/settings/analysis.go index 99b55cc6b24..48a783ac486 100644 --- a/gopls/internal/settings/analysis.go +++ b/gopls/internal/settings/analysis.go @@ -56,6 +56,7 @@ import ( "golang.org/x/tools/gopls/internal/analysis/modernize" "golang.org/x/tools/gopls/internal/analysis/nonewvars" "golang.org/x/tools/gopls/internal/analysis/noresultvalues" + 
"golang.org/x/tools/gopls/internal/analysis/recursiveiter" "golang.org/x/tools/gopls/internal/analysis/simplifycompositelit" "golang.org/x/tools/gopls/internal/analysis/simplifyrange" "golang.org/x/tools/gopls/internal/analysis/simplifyslice" @@ -207,8 +208,9 @@ var DefaultAnalyzers = []*Analyzer{ {analyzer: yield.Analyzer}, // uses go/ssa {analyzer: sortslice.Analyzer}, {analyzer: embeddirective.Analyzer}, - {analyzer: waitgroup.Analyzer}, // to appear in cmd/vet@go1.25 - {analyzer: hostport.Analyzer}, // to appear in cmd/vet@go1.25 + {analyzer: waitgroup.Analyzer}, // to appear in cmd/vet@go1.25 + {analyzer: hostport.Analyzer}, // to appear in cmd/vet@go1.25 + {analyzer: recursiveiter.Analyzer}, // under evaluation // disabled due to high false positives {analyzer: shadow.Analyzer, nonDefault: true}, // very noisy diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 4288b756da9..261add7b3b7 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -784,7 +784,12 @@ func loadMarkerTest(name string, content []byte) (*markerTest, error) { files: make(map[string][]byte), golden: make(map[expect.Identifier]*Golden), } + seen := make(map[string]bool) for _, file := range archive.Files { + if seen[file.Name] { + return nil, fmt.Errorf("duplicate archive section %q", file.Name) + } + seen[file.Name] = true switch { case file.Name == "skip": reason := strings.ReplaceAll(string(file.Data), "\n", " ") diff --git a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt index ba9f125ebd6..2fa9fbeb2cc 100644 --- a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt +++ b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt @@ -8,16 +8,18 @@ copylocks, printf, slog, tests, timeformat, nilness, and cgocall. 
-- go.mod -- module example.com -go 1.18 +go 1.23 -- flags -- +-min_go_command=go1.23 -cgo --- bad_test.go -- -package analyzer +-- bad/bad_test.go -- +package bad import ( "fmt" + "iter" "log/slog" "sync" "testing" @@ -32,7 +34,7 @@ func _() { // printf func _() { - printfWrapper("%s") //@diag(re`printfWrapper\(.*?\)`, re"example.com.printfWrapper format %s reads arg #1, but call has 0 args") + printfWrapper("%s") //@diag(re`printfWrapper\(.*?\)`, re"example.com/bad.printfWrapper format %s reads arg #1, but call has 0 args") } func printfWrapper(format string, args ...any) { @@ -76,12 +78,19 @@ func _() { // inline func _() { - f() //@diag("f", re"Call of analyzer.f should be inlined") + f() //@diag("f", re"Call of bad.f should be inlined") } //go:fix inline func f() { fmt.Println(1) } +// recursiveiter +func F() iter.Seq[int] { + return func(yield func(int) bool) { + for range F() {} //@ diag("range", re"inefficient recursion in iterator F") + } +} + -- cgocall/cgocall.go -- package cgocall @@ -112,4 +121,3 @@ func S1011(x, y []int) { x = append(x, e) // no "replace loop with append" diagnostic } } - From edbd9df7a79b6d773025c5b7a9b2c2adcb3ae9f1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 26 Apr 2025 20:09:26 -0400 Subject: [PATCH 109/196] jsonschema: dynamic references A dynamic reference is like a dynamic variable in a programming language, with two twists: - There must be a dynamic anchor in lexical scope for the ref to resolve. (Think of it as declaring the dynamic variable.) - The ref resolves to the schema highest on the stack, not lowest. That feels odd, but it turns out to work for the intended purpose, which is to provide extension points for additional features. There is some distracting code motion and refactoring in this CL: resolveRefs was moved down in the file, and resolution of a single ref was split off. 
Change-Id: I17da6052e0782fc540d14b8a41b06a2645923872 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674155 Auto-Submit: Jonathan Amsterdam Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/resolve.go | 186 ++-- internal/mcp/jsonschema/resolve_test.go | 17 +- internal/mcp/jsonschema/schema.go | 49 +- .../testdata/draft2020-12/anchor.json | 120 +++ .../testdata/draft2020-12/defs.json | 21 + .../testdata/draft2020-12/dynamicRef.json | 815 ++++++++++++++++++ internal/mcp/jsonschema/validate.go | 51 +- internal/mcp/jsonschema/validate_test.go | 5 - 8 files changed, 1163 insertions(+), 101 deletions(-) create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/anchor.json create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/defs.json create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/dynamicRef.json diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 2ba51443773..f82eec1b78f 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -102,7 +102,7 @@ func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { // which may differ if the schema has an $id. // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. r.loaded[baseURI.String()] = rs - r.loaded[s.baseURI.String()] = rs + r.loaded[s.uri.String()] = rs if err := r.resolveRefs(rs); err != nil { return nil, err @@ -110,57 +110,6 @@ func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { return rs, nil } -// resolveRefs replaces all refs in the schemas with the schema they refer to. -// A reference that doesn't resolve within the schema may refer to some other schema -// that needs to be loaded. 
-func (r *resolver) resolveRefs(rs *Resolved) error { - for s := range rs.root.all() { - if s.Ref == "" { - continue - } - refURI, err := url.Parse(s.Ref) - if err != nil { - return err - } - // URI-resolve the ref against the current base URI to get a complete URI. - refURI = s.baseURI.ResolveReference(refURI) - // The non-fragment part of a ref URI refers to the base URI of some schema. - u := *refURI - u.Fragment = "" - fraglessRefURI := &u - // Look it up locally. - referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] - if referencedSchema == nil { - // The schema is remote. Maybe we've already loaded it. - // We assume that the non-fragment part of refURI refers to a top-level schema - // document. That is, we don't support the case exemplified by - // http://foo.com/bar.json/baz, where the document is in bar.json and - // the reference points to a subschema within it. - // TODO: support that case. - loadedResolved := r.loaded[fraglessRefURI.String()] - if loadedResolved == nil { - // Try to load the schema. - ls, err := r.loader(fraglessRefURI) - if err != nil { - return fmt.Errorf("loading %s: %w", fraglessRefURI, err) - } - loadedResolved, err = r.resolve(ls, fraglessRefURI) - if err != nil { - return err - } - } - referencedSchema = loadedResolved.root - assert(referencedSchema != nil, "nil referenced schema") - } - // The fragment selects the referenced schema, or a subschema of it. - s.resolvedRef, err = lookupFragment(referencedSchema, refURI.Fragment) - if err != nil { - return err - } - } - return nil -} - func (root *Schema) check() error { if root == nil { return errors.New("nil schema") @@ -267,10 +216,7 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { var resolve func(s, base *Schema) error resolve = func(s, base *Schema) error { // ids are scoped to the root. - if s.ID == "" { - // If a schema doesn't have an $id, its base is the parent base. 
- s.baseURI = base.baseURI - } else { + if s.ID != "" { // A non-empty ID establishes a new base. idURI, err := url.Parse(s.ID) if err != nil { @@ -280,26 +226,33 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { return fmt.Errorf("$id %s must not have a fragment", s.ID) } // The base URI for this schema is its $id resolved against the parent base. - s.baseURI = base.baseURI.ResolveReference(idURI) - if !s.baseURI.IsAbs() { - return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %s)", s.ID, s.baseURI) + s.uri = base.uri.ResolveReference(idURI) + if !s.uri.IsAbs() { + return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %s)", s.ID, s.base.uri) } - resolvedURIs[s.baseURI.String()] = s + resolvedURIs[s.uri.String()] = s base = s // needed for anchors } + s.base = base - // Anchors are URI fragments that are scoped to their base. + // Anchors and dynamic anchors are URI fragments that are scoped to their base. // We treat them as keys in a map stored within the schema. - if s.Anchor != "" { - if base.anchors[s.Anchor] != nil { - return fmt.Errorf("duplicate anchor %q in %s", s.Anchor, base.baseURI) - } - if base.anchors == nil { - base.anchors = map[string]*Schema{} + setAnchor := func(anchor string, dynamic bool) error { + if anchor != "" { + if _, ok := base.anchors[anchor]; ok { + return fmt.Errorf("duplicate anchor %q in %s", anchor, base.uri) + } + if base.anchors == nil { + base.anchors = map[string]anchorInfo{} + } + base.anchors[anchor] = anchorInfo{s, dynamic} } - base.anchors[s.Anchor] = s + return nil } + setAnchor(s.Anchor, false) + setAnchor(s.DynamicAnchor, true) + for c := range s.children() { if err := resolve(c, base); err != nil { return err @@ -308,8 +261,8 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { return nil } - // Set the root URI to the base for now. If the root has an $id, the base will change. 
- root.baseURI = baseURI + // Set the root URI to the base for now. If the root has an $id, this will change. + root.uri = baseURI // The original base, even if changed, is still a valid way to refer to the root. resolvedURIs[baseURI.String()] = root if err := resolve(root, root); err != nil { @@ -318,18 +271,95 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { return resolvedURIs, nil } -// lookupFragment returns the schema referenced by frag in s, or an error -// if there isn't one or something else went wrong. -func lookupFragment(s *Schema, frag string) (*Schema, error) { +// resolveRefs replaces every ref in the schemas with the schema it refers to. +// A reference that doesn't resolve within the schema may refer to some other schema +// that needs to be loaded. +func (r *resolver) resolveRefs(rs *Resolved) error { + for s := range rs.root.all() { + if s.Ref != "" { + refSchema, _, err := r.resolveRef(rs, s, s.Ref) + if err != nil { + return err + } + // Whether or not the anchor referred to by $ref fragment is dynamic, + // the ref still treats it lexically. + s.resolvedRef = refSchema + } + if s.DynamicRef != "" { + refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) + if err != nil { + return err + } + if frag != "" { + // The dynamic ref's fragment points to a dynamic anchor. + // We must resolve the fragment at validation time. + s.dynamicRefAnchor = frag + } else { + // There is no dynamic anchor in the lexically referenced schema, + // so the dynamic ref behaves like a lexical ref. + s.resolvedDynamicRef = refSchema + } + } + } + return nil +} + +// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. +func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { + refURI, err := url.Parse(ref) + if err != nil { + return nil, "", err + } + // URI-resolve the ref against the current base URI to get a complete URI. 
+ refURI = s.base.uri.ResolveReference(refURI) + // The non-fragment part of a ref URI refers to the base URI of some schema. + // This part is the same for dynamic refs too: their non-fragment part resolves + // lexically. + u := *refURI + u.Fragment = "" + fraglessRefURI := &u + // Look it up locally. + referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] + if referencedSchema == nil { + // The schema is remote. Maybe we've already loaded it. + // We assume that the non-fragment part of refURI refers to a top-level schema + // document. That is, we don't support the case exemplified by + // http://foo.com/bar.json/baz, where the document is in bar.json and + // the reference points to a subschema within it. + // TODO: support that case. + if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil { + referencedSchema = lrs.root + } else { + // Try to load the schema. + ls, err := r.loader(fraglessRefURI) + if err != nil { + return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) + } + lrs, err := r.resolve(ls, fraglessRefURI) + if err != nil { + return nil, "", err + } + referencedSchema = lrs.root + assert(referencedSchema != nil, "nil referenced schema") + } + } + + frag := refURI.Fragment + // Look up frag in refSchema. // frag is either a JSON Pointer or the name of an anchor. // A JSON Pointer is either the empty string or begins with a '/', // whereas anchors are always non-empty strings that don't contain slashes. if frag != "" && !strings.HasPrefix(frag, "/") { - if fs := s.anchors[frag]; fs != nil { - return fs, nil + info, found := referencedSchema.anchors[frag] + if !found { + return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) + } + if info.dynamic { + dynamicFragment = frag } - return nil, fmt.Errorf("no anchor %q in %s", frag, s) + return info.schema, dynamicFragment, nil } - // frag is a JSON Pointer. Follow it. - return dereferenceJSONPointer(s, frag) + // frag is a JSON Pointer. 
+ s, err = dereferenceJSONPointer(referencedSchema, frag) + return s, "", err } diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 67b2fe0f687..717bdbb0c04 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -66,8 +66,9 @@ func TestResolveURIs(t *testing.T) { ID: "/foo.json", }, Contains: &Schema{ - ID: "/bar.json", - Anchor: "a", + ID: "/bar.json", + Anchor: "a", + DynamicAnchor: "da", Items: &Schema{ Anchor: "b", Items: &Schema{ @@ -97,9 +98,15 @@ func TestResolveURIs(t *testing.T) { if baseURI != root.ID { wantIDs[root.ID] = root } - wantAnchors := map[*Schema]map[string]*Schema{ - root.Contains: {"a": root.Contains, "b": root.Contains.Items}, - root.Contains.Items.Items: {"c": root.Contains.Items.Items}, + wantAnchors := map[*Schema]map[string]anchorInfo{ + root.Contains: { + "a": anchorInfo{root.Contains, false}, + "da": anchorInfo{root.Contains, true}, + "b": anchorInfo{root.Contains.Items, false}, + }, + root.Contains.Items.Items: { + "c": anchorInfo{root.Contains.Items.Items, false}, + }, } gotKeys := slices.Sorted(maps.Keys(got)) diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index f1e16a5decc..1ec4b0d4bf2 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -8,6 +8,7 @@ package jsonschema import ( "bytes" + "cmp" "encoding/json" "errors" "fmt" @@ -110,15 +111,35 @@ type Schema struct { DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` // computed fields - // If the schema doesn't have an ID, the base URI is that of its parent. - // Otherwise, the base URI is the ID resolved against the parent's baseURI. - // The parent base URI at top level is where the schema was loaded from, or - // if not loaded, then it should be provided to Schema.Resolve. - baseURI *url.URL + + // This schema's base schema. + // If the schema is the root or has an ID, its base is itself. 
+ // Otherwise, its base is the innermost enclosing schema whose base + // is itself. + // Intuitively, a base schema is one that can be referred to with a + // fragmentless URI. + base *Schema + + // The URI for the schema, if it is the root or has an ID. + // Otherwise nil. + // Invariants: + // s.base.uri != nil. + // s.base == s <=> s.uri != nil + uri *url.URL + // The schema to which Ref refers. resolvedRef *Schema - // map from anchors to subschemas - anchors map[string]*Schema + + // If the schema has a dynamic ref, exactly one of the next two fields + // will be non-zero after successful resolution. + // The schema to which the dynamic ref refers when it acts lexically. + resolvedDynamicRef *Schema + // The anchor to look up on the stack when the dynamic ref acts dynamically. + dynamicRefAnchor string + + // Map from anchors to subschemas. + anchors map[string]anchorInfo + // compiled regexps pattern *regexp.Regexp patternProperties map[*regexp.Regexp]*Schema @@ -129,10 +150,20 @@ func falseSchema() *Schema { return &Schema{Not: &Schema{}} } +// anchorInfo records the subschema to which an anchor refers, and whether +// the anchor keyword is $anchor or $dynamicAnchor. +type anchorInfo struct { + schema *Schema + dynamic bool +} + // String returns a short description of the schema. func (s *Schema) String() string { - if s.ID != "" { - return s.ID + if s.uri != nil { + return s.uri.String() + } + if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { + return fmt.Sprintf("%q, anchor %s", s.base.uri.String(), a) } // TODO: return something better, like a JSON Pointer from the base. 
return "" diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/anchor.json b/internal/mcp/jsonschema/testdata/draft2020-12/anchor.json new file mode 100644 index 00000000000..99143fa1160 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/anchor.json @@ -0,0 +1,120 @@ +[ + { + "description": "Location-independent identifier", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#foo", + "$defs": { + "A": { + "$anchor": "foo", + "type": "integer" + } + } + }, + "tests": [ + { + "data": 1, + "description": "match", + "valid": true + }, + { + "data": "a", + "description": "mismatch", + "valid": false + } + ] + }, + { + "description": "Location-independent identifier with absolute URI", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "http://localhost:1234/draft2020-12/bar#foo", + "$defs": { + "A": { + "$id": "http://localhost:1234/draft2020-12/bar", + "$anchor": "foo", + "type": "integer" + } + } + }, + "tests": [ + { + "data": 1, + "description": "match", + "valid": true + }, + { + "data": "a", + "description": "mismatch", + "valid": false + } + ] + }, + { + "description": "Location-independent identifier with base URI change in subschema", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/root", + "$ref": "http://localhost:1234/draft2020-12/nested.json#foo", + "$defs": { + "A": { + "$id": "nested.json", + "$defs": { + "B": { + "$anchor": "foo", + "type": "integer" + } + } + } + } + }, + "tests": [ + { + "data": 1, + "description": "match", + "valid": true + }, + { + "data": "a", + "description": "mismatch", + "valid": false + } + ] + }, + { + "description": "same $anchor with different base uri", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/foobar", + "$defs": { + "A": { + "$id": "child1", + "allOf": [ + { + "$id": "child2", + 
"$anchor": "my_anchor", + "type": "number" + }, + { + "$anchor": "my_anchor", + "type": "string" + } + ] + } + }, + "$ref": "child1#my_anchor" + }, + "tests": [ + { + "description": "$ref resolves to /$defs/A/allOf/1", + "data": "a", + "valid": true + }, + { + "description": "$ref does not resolve to /$defs/A/allOf/0", + "data": 1, + "valid": false + } + ] + } +] diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/defs.json b/internal/mcp/jsonschema/testdata/draft2020-12/defs.json new file mode 100644 index 00000000000..da2a503bfb9 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/defs.json @@ -0,0 +1,21 @@ +[ + { + "description": "validate definition against metaschema", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "https://json-schema.org/draft/2020-12/schema" + }, + "tests": [ + { + "description": "valid definition schema", + "data": {"$defs": {"foo": {"type": "integer"}}}, + "valid": true + }, + { + "description": "invalid definition schema", + "data": {"$defs": {"foo": {"type": 1}}}, + "valid": false + } + ] + } +] diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/dynamicRef.json b/internal/mcp/jsonschema/testdata/draft2020-12/dynamicRef.json new file mode 100644 index 00000000000..ffa211ba2f6 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/dynamicRef.json @@ -0,0 +1,815 @@ +[ + { + "description": "A $dynamicRef to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamicRef-dynamicAnchor-same-schema/root", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + } + } + }, + "tests": [ + { + "description": "An array of strings is valid", + "data": ["foo", "bar"], + "valid": true + }, + { + "description": "An array containing non-strings is 
invalid", + "data": ["foo", 42], + "valid": false + } + ] + }, + { + "description": "A $dynamicRef to an $anchor in the same schema resource behaves like a normal $ref to an $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamicRef-anchor-same-schema/root", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "foo": { + "$anchor": "items", + "type": "string" + } + } + }, + "tests": [ + { + "description": "An array of strings is valid", + "data": ["foo", "bar"], + "valid": true + }, + { + "description": "An array containing non-strings is invalid", + "data": ["foo", 42], + "valid": false + } + ] + }, + { + "description": "A $ref to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/ref-dynamicAnchor-same-schema/root", + "type": "array", + "items": { "$ref": "#items" }, + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + } + } + }, + "tests": [ + { + "description": "An array of strings is valid", + "data": ["foo", "bar"], + "valid": true + }, + { + "description": "An array containing non-strings is invalid", + "data": ["foo", 42], + "valid": false + } + ] + }, + { + "description": "A $dynamicRef resolves to the first $dynamicAnchor still in scope that is encountered when the schema is evaluated", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/typical-dynamic-resolution/root", + "$ref": "list", + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "items": { + "$comment": "This is only needed to satisfy the bookending requirement", + "$dynamicAnchor": "items" + } + } + } + } + }, + "tests": [ + { + "description": 
"An array of strings is valid", + "data": ["foo", "bar"], + "valid": true + }, + { + "description": "An array containing non-strings is invalid", + "data": ["foo", 42], + "valid": false + } + ] + }, + { + "description": "A $dynamicRef without anchor in fragment behaves identical to $ref", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamicRef-without-anchor/root", + "$ref": "list", + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#/$defs/items" }, + "$defs": { + "items": { + "$comment": "This is only needed to satisfy the bookending requirement", + "$dynamicAnchor": "items", + "type": "number" + } + } + } + } + }, + "tests": [ + { + "description": "An array of strings is invalid", + "data": ["foo", "bar"], + "valid": false + }, + { + "description": "An array of numbers is valid", + "data": [24, 42], + "valid": true + } + ] + }, + { + "description": "A $dynamicRef with intermediate scopes that don't include a matching $dynamicAnchor does not affect dynamic scope resolution", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamic-resolution-with-intermediate-scopes/root", + "$ref": "intermediate-scope", + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + }, + "intermediate-scope": { + "$id": "intermediate-scope", + "$ref": "list" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "items": { + "$comment": "This is only needed to satisfy the bookending requirement", + "$dynamicAnchor": "items" + } + } + } + } + }, + "tests": [ + { + "description": "An array of strings is valid", + "data": ["foo", "bar"], + "valid": true + }, + { + "description": "An array containing non-strings is invalid", + "data": ["foo", 42], + "valid": false + } + ] + }, + { + 
"description": "An $anchor with the same name as a $dynamicAnchor is not used for dynamic scope resolution", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamic-resolution-ignores-anchors/root", + "$ref": "list", + "$defs": { + "foo": { + "$anchor": "items", + "type": "string" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "items": { + "$comment": "This is only needed to satisfy the bookending requirement", + "$dynamicAnchor": "items" + } + } + } + } + }, + "tests": [ + { + "description": "Any array is valid", + "data": ["foo", 42], + "valid": true + } + ] + }, + { + "description": "A $dynamicRef without a matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamic-resolution-without-bookend/root", + "$ref": "list", + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "items": { + "$comment": "This is only needed to give the reference somewhere to resolve to when it behaves like $ref", + "$anchor": "items" + } + } + } + } + }, + "tests": [ + { + "description": "Any array is valid", + "data": ["foo", 42], + "valid": true + } + ] + }, + { + "description": "A $dynamicRef with a non-matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/unmatched-dynamic-anchor/root", + "$ref": "list", + "$defs": { + "foo": { + "$dynamicAnchor": "items", + "type": "string" + }, + "list": { + "$id": "list", + "type": "array", + "items": { "$dynamicRef": "#items" }, + "$defs": { + "items": { + "$comment": "This is only needed to give 
the reference somewhere to resolve to when it behaves like $ref", + "$anchor": "items", + "$dynamicAnchor": "foo" + } + } + } + } + }, + "tests": [ + { + "description": "Any array is valid", + "data": ["foo", 42], + "valid": true + } + ] + }, + { + "description": "A $dynamicRef that initially resolves to a schema with a matching $dynamicAnchor resolves to the first $dynamicAnchor in the dynamic scope", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/relative-dynamic-reference/root", + "$dynamicAnchor": "meta", + "type": "object", + "properties": { + "foo": { "const": "pass" } + }, + "$ref": "extended", + "$defs": { + "extended": { + "$id": "extended", + "$dynamicAnchor": "meta", + "type": "object", + "properties": { + "bar": { "$ref": "bar" } + } + }, + "bar": { + "$id": "bar", + "type": "object", + "properties": { + "baz": { "$dynamicRef": "extended#meta" } + } + } + } + }, + "tests": [ + { + "description": "The recursive part is valid against the root", + "data": { + "foo": "pass", + "bar": { + "baz": { "foo": "pass" } + } + }, + "valid": true + }, + { + "description": "The recursive part is not valid against the root", + "data": { + "foo": "pass", + "bar": { + "baz": { "foo": "fail" } + } + }, + "valid": false + } + ] + }, + { + "description": "A $dynamicRef that initially resolves to a schema without a matching $dynamicAnchor behaves like a normal $ref to $anchor", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/relative-dynamic-reference-without-bookend/root", + "$dynamicAnchor": "meta", + "type": "object", + "properties": { + "foo": { "const": "pass" } + }, + "$ref": "extended", + "$defs": { + "extended": { + "$id": "extended", + "$anchor": "meta", + "type": "object", + "properties": { + "bar": { "$ref": "bar" } + } + }, + "bar": { + "$id": "bar", + "type": "object", + "properties": { + "baz": { "$dynamicRef": "extended#meta" } + 
} + } + } + }, + "tests": [ + { + "description": "The recursive part doesn't need to validate against the root", + "data": { + "foo": "pass", + "bar": { + "baz": { "foo": "fail" } + } + }, + "valid": true + } + ] + }, + { + "description": "multiple dynamic paths to the $dynamicRef keyword", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamic-ref-with-multiple-paths/main", + "if": { + "properties": { + "kindOfList": { "const": "numbers" } + }, + "required": ["kindOfList"] + }, + "then": { "$ref": "numberList" }, + "else": { "$ref": "stringList" }, + + "$defs": { + "genericList": { + "$id": "genericList", + "properties": { + "list": { + "items": { "$dynamicRef": "#itemType" } + } + }, + "$defs": { + "defaultItemType": { + "$comment": "Only needed to satisfy bookending requirement", + "$dynamicAnchor": "itemType" + } + } + }, + "numberList": { + "$id": "numberList", + "$defs": { + "itemType": { + "$dynamicAnchor": "itemType", + "type": "number" + } + }, + "$ref": "genericList" + }, + "stringList": { + "$id": "stringList", + "$defs": { + "itemType": { + "$dynamicAnchor": "itemType", + "type": "string" + } + }, + "$ref": "genericList" + } + } + }, + "tests": [ + { + "description": "number list with number values", + "data": { + "kindOfList": "numbers", + "list": [1.1] + }, + "valid": true + }, + { + "description": "number list with string values", + "data": { + "kindOfList": "numbers", + "list": ["foo"] + }, + "valid": false + }, + { + "description": "string list with number values", + "data": { + "kindOfList": "strings", + "list": [1.1] + }, + "valid": false + }, + { + "description": "string list with string values", + "data": { + "kindOfList": "strings", + "list": ["foo"] + }, + "valid": true + } + ] + }, + { + "description": "after leaving a dynamic scope, it is not used by a $dynamicRef", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": 
"https://test.json-schema.org/dynamic-ref-leaving-dynamic-scope/main", + "if": { + "$id": "first_scope", + "$defs": { + "thingy": { + "$comment": "this is first_scope#thingy", + "$dynamicAnchor": "thingy", + "type": "number" + } + } + }, + "then": { + "$id": "second_scope", + "$ref": "start", + "$defs": { + "thingy": { + "$comment": "this is second_scope#thingy, the final destination of the $dynamicRef", + "$dynamicAnchor": "thingy", + "type": "null" + } + } + }, + "$defs": { + "start": { + "$comment": "this is the landing spot from $ref", + "$id": "start", + "$dynamicRef": "inner_scope#thingy" + }, + "thingy": { + "$comment": "this is the first stop for the $dynamicRef", + "$id": "inner_scope", + "$dynamicAnchor": "thingy", + "type": "string" + } + } + }, + "tests": [ + { + "description": "string matches /$defs/thingy, but the $dynamicRef does not stop here", + "data": "a string", + "valid": false + }, + { + "description": "first_scope is not in dynamic scope for the $dynamicRef", + "data": 42, + "valid": false + }, + { + "description": "/then/$defs/thingy is the final stop for the $dynamicRef", + "data": null, + "valid": true + } + ] + }, + { + "description": "strict-tree schema, guards against misspelled properties", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/strict-tree.json", + "$dynamicAnchor": "node", + + "$ref": "tree.json", + "unevaluatedProperties": false + }, + "tests": [ + { + "description": "instance with misspelled field", + "data": { + "children": [{ + "daat": 1 + }] + }, + "valid": false + }, + { + "description": "instance with correct field", + "data": { + "children": [{ + "data": 1 + }] + }, + "valid": true + } + ] + }, + { + "description": "tests for implementation dynamic anchor and reference link", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/strict-extendible.json", + "$ref": 
"extendible-dynamic-ref.json", + "$defs": { + "elements": { + "$dynamicAnchor": "elements", + "properties": { + "a": true + }, + "required": ["a"], + "additionalProperties": false + } + } + }, + "tests": [ + { + "description": "incorrect parent schema", + "data": { + "a": true + }, + "valid": false + }, + { + "description": "incorrect extended schema", + "data": { + "elements": [ + { "b": 1 } + ] + }, + "valid": false + }, + { + "description": "correct extended schema", + "data": { + "elements": [ + { "a": 1 } + ] + }, + "valid": true + } + ] + }, + { + "description": "$ref and $dynamicAnchor are independent of order - $defs first", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-defs-first.json", + "allOf": [ + { + "$ref": "extendible-dynamic-ref.json" + }, + { + "$defs": { + "elements": { + "$dynamicAnchor": "elements", + "properties": { + "a": true + }, + "required": ["a"], + "additionalProperties": false + } + } + } + ] + }, + "tests": [ + { + "description": "incorrect parent schema", + "data": { + "a": true + }, + "valid": false + }, + { + "description": "incorrect extended schema", + "data": { + "elements": [ + { "b": 1 } + ] + }, + "valid": false + }, + { + "description": "correct extended schema", + "data": { + "elements": [ + { "a": 1 } + ] + }, + "valid": true + } + ] + }, + { + "description": "$ref and $dynamicAnchor are independent of order - $ref first", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-ref-first.json", + "allOf": [ + { + "$defs": { + "elements": { + "$dynamicAnchor": "elements", + "properties": { + "a": true + }, + "required": ["a"], + "additionalProperties": false + } + } + }, + { + "$ref": "extendible-dynamic-ref.json" + } + ] + }, + "tests": [ + { + "description": "incorrect parent schema", + "data": { + "a": true + }, + "valid": false + }, + { 
+ "description": "incorrect extended schema", + "data": { + "elements": [ + { "b": 1 } + ] + }, + "valid": false + }, + { + "description": "correct extended schema", + "data": { + "elements": [ + { "a": 1 } + ] + }, + "valid": true + } + ] + }, + { + "description": "$ref to $dynamicRef finds detached $dynamicAnchor", + "schema": { + "$ref": "http://localhost:1234/draft2020-12/detached-dynamicref.json#/$defs/foo" + }, + "tests": [ + { + "description": "number is valid", + "data": 1, + "valid": true + }, + { + "description": "non-number is invalid", + "data": "a", + "valid": false + } + ] + }, + { + "description": "$dynamicRef points to a boolean schema", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "true": true, + "false": false + }, + "properties": { + "true": { + "$dynamicRef": "#/$defs/true" + }, + "false": { + "$dynamicRef": "#/$defs/false" + } + } + }, + "tests": [ + { + "description": "follow $dynamicRef to a true schema", + "data": { "true": 1 }, + "valid": true + }, + { + "description": "follow $dynamicRef to a false schema", + "data": { "false": 1 }, + "valid": false + } + ] + }, + { + "description": "$dynamicRef skips over intermediate resources - direct reference", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://test.json-schema.org/dynamic-ref-skips-intermediate-resource/main", + "type": "object", + "properties": { + "bar-item": { + "$ref": "item" + } + }, + "$defs": { + "bar": { + "$id": "bar", + "type": "array", + "items": { + "$ref": "item" + }, + "$defs": { + "item": { + "$id": "item", + "type": "object", + "properties": { + "content": { + "$dynamicRef": "#content" + } + }, + "$defs": { + "defaultContent": { + "$dynamicAnchor": "content", + "type": "integer" + } + } + }, + "content": { + "$dynamicAnchor": "content", + "type": "string" + } + } + } + } + }, + "tests": [ + { + "description": "integer property passes", + "data": { "bar-item": { "content": 42 } }, + 
"valid": true + }, + { + "description": "string property fails", + "data": { "bar-item": { "content": "value" } }, + "valid": false + } + ] + } +] diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index b529e232ad5..bc1701428c6 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -35,6 +35,10 @@ func (rs *Resolved) Validate(instance any) error { type state struct { rs *Resolved depth int + // stack holds the schemas from recursive calls to validate. + // These are the "dynamic scopes" used to resolve dynamic references. + // https://json-schema.org/draft/2020-12/json-schema-core#scopes + stack []*Schema } // validate validates the reflected value of the instance. @@ -48,10 +52,12 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } }() - st.depth++ - defer func() { st.depth-- }() - if st.depth >= 100 { - return fmt.Errorf("max recursion depth of %d reached", st.depth) + st.stack = append(st.stack, schema) // push + defer func() { + st.stack = st.stack[:len(st.stack)-1] // pop + }() + if depth := len(st.stack); depth >= 100 { + return fmt.Errorf("max recursion depth of %d reached", depth) } // We checked for nil schemas in [Schema.Resolve]. @@ -162,6 +168,43 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } } + // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 + if schema.DynamicRef != "" { + // The ref behaves lexically or dynamically, but not both. + assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), + "DynamicRef not resolved properly") + if schema.resolvedDynamicRef != nil { + // Same as $ref. + if err := st.validate(instance, schema.resolvedDynamicRef, &anns, path); err != nil { + return err + } + } else { + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. 
(Yes, outermost: the one farthest from here. This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. + var dynamicSchema *Schema + for _, s := range st.stack { + info, ok := s.base.anchors[schema.dynamicRefAnchor] + if ok && info.dynamic { + dynamicSchema = info.schema + break + } + } + if dynamicSchema == nil { + return fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) + } + if err := st.validate(instance, dynamicSchema, &anns, path); err != nil { + return err + } + } + } + // logic // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 // These must happen before arrays and objects because if they evaluate an item or property, diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index bd66560ef83..3d096dfcef1 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -54,11 +54,6 @@ func TestValidate(t *testing.T) { } for _, g := range groups { t.Run(g.Description, func(t *testing.T) { - for s := range g.Schema.all() { - if s.DynamicAnchor != "" || s.DynamicRef != "" { - t.Skip("schema or subschema has unimplemented keywords") - } - } rs, err := g.Schema.Resolve("", loadRemote) if err != nil { t.Fatal(err) From 87749a790071448436810a689ae374a82f7c0a09 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 21 May 2025 13:47:00 +0000 Subject: [PATCH 110/196] all: fix typos Fix some typos found during the Google import of x/tools. 
Change-Id: I8ea38d3e1bd09bedbc98ed8c8146e88e84aa051f Reviewed-on: https://go-review.googlesource.com/c/tools/+/674975 Auto-Submit: Robert Findley Reviewed-by: Hongxiang Jiang LUCI-TryBot-Result: Go LUCI --- gopls/internal/cmd/serve.go | 8 ++++---- internal/mcp/design/design.md | 2 +- internal/mcp/shared.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gopls/internal/cmd/serve.go b/gopls/internal/cmd/serve.go index 761895a73e2..67f27093af0 100644 --- a/gopls/internal/cmd/serve.go +++ b/gopls/internal/cmd/serve.go @@ -125,14 +125,14 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { group, ctx := errgroup.WithContext(ctx) // Indicate success by a special error so that successful termination // of one server causes cancellation of the other. - sucess := errors.New("success") + success := errors.New("success") // Start MCP server. if eventChan != nil { group.Go(func() (err error) { defer func() { if err == nil { - err = sucess + err = success } }() @@ -149,7 +149,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { close(eventChan) } if err == nil { - err = sucess + err = success } }() @@ -200,7 +200,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { // Wait for all servers to terminate, returning only the first error // encountered. Subsequent errors are typically due to context cancellation // and are disregarded. 
- if err := group.Wait(); err != nil && !errors.Is(err, sucess) { + if err := group.Wait(); err != nil && !errors.Is(err, success) { return err } return nil diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 2cf0f5a2d0c..91963d8f84d 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -744,7 +744,7 @@ func (s *Server) RemoveResources(uris ...string) func (s *Server) RemoveResourceTemplates(uriTemplates ...string) ``` -The `ReadResource` method finds a resource or resource template matching the argument URI and calls its assocated handler. +The `ReadResource` method finds a resource or resource template matching the argument URI and calls its associated handler. To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index e82ce20cab5..cd05f407743 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -139,7 +139,7 @@ func sessionMethod[S ClientSession | ServerSession, P, R any](f func(*S, context const ( // The error code to return when a resource isn't found. // See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling - // However, the code they chose in in the wrong space + // However, the code they chose is in the wrong space // (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). // so we pick a different one, arbirarily for now (until they fix it). // The immediate problem is that jsonprc2 defines -32002 as "server closing". From dc3456874973f675cfc03ab8cb9020111c111968 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 20 May 2025 16:06:34 -0400 Subject: [PATCH 111/196] gopls/internal/test/integration/completion: relax expectations go1.25 adds TB.Attr; relax the test expecations accordingly. 
Change-Id: I5aa2bc28f8f5935f4438c8548b4c02ef8d6c90c4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674595 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- .../test/integration/completion/completion18_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gopls/internal/test/integration/completion/completion18_test.go b/gopls/internal/test/integration/completion/completion18_test.go index a35061d693b..c834878e4f5 100644 --- a/gopls/internal/test/integration/completion/completion18_test.go +++ b/gopls/internal/test/integration/completion/completion18_test.go @@ -90,13 +90,18 @@ func FuzzHex(f *testing.F) { -- c_test.go -- ` + part0 + part1 + part2 + ad := []string{"Add"} + if _, ok := any(t).(interface{ Attr(k, v string) }); ok { // go1.25 added TBF.Attr + ad = append(ad, "Attr") + } + tests := []struct { file string pat string offset uint32 // UTF16 length from the beginning of pat to what the user just typed want []string }{ - {"a_test.go", "f.Ad", 3, []string{"Add"}}, + {"a_test.go", "f.Ad", 3, ad}, {"c_test.go", " f.F", 4, []string{"Failed"}}, {"c_test.go", "f.N", 3, []string{"Name"}}, {"b_test.go", "f.F", 3, []string{"Fuzz(func(t *testing.T, a []byte)", "Fail", "FailNow", @@ -111,7 +116,7 @@ func FuzzHex(f *testing.F) { completions := env.Completion(loc) result := compareCompletionLabels(test.want, completions.Items) if result != "" { - t.Errorf("pat %q %q", test.pat, result) + t.Errorf("pat=%q <<%s>>", test.pat, result) for i, it := range completions.Items { t.Errorf("%d got %q %q", i, it.Label, it.Detail) } From 423c5afcceff141ac88ff673700825624dcf6aa2 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 21 May 2025 15:47:37 -0400 Subject: [PATCH 112/196] gopls/internal/analysis/recursiveiter: set Diagnostic.End Change-Id: Ie77f08435ded0ce54ac2a8e23691546dd5c76343 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675275 Auto-Submit: Alan Donovan Commit-Queue: Alan Donovan Reviewed-by: Robert Findley 
LUCI-TryBot-Result: Go LUCI --- gopls/internal/analysis/recursiveiter/recursiveiter.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gopls/internal/analysis/recursiveiter/recursiveiter.go b/gopls/internal/analysis/recursiveiter/recursiveiter.go index 0064a37386f..364855ba418 100644 --- a/gopls/internal/analysis/recursiveiter/recursiveiter.go +++ b/gopls/internal/analysis/recursiveiter/recursiveiter.go @@ -6,6 +6,7 @@ package recursiveiter import ( _ "embed" + "fmt" "go/ast" "go/types" @@ -79,7 +80,11 @@ func run(pass *analysis.Pass) (any, error) { continue } if typeutil.StaticCallee(info, call) == fn { - pass.Reportf(rng.Range, "inefficient recursion in iterator %s", fn.Name()) + pass.Report(analysis.Diagnostic{ + Pos: rng.Range, + End: rng.X.End(), + Message: fmt.Sprintf("inefficient recursion in iterator %s", fn.Name()), + }) } } } From aebd3be41bc6c85777b98b34cefc2c561c5f85b0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 21 May 2025 14:59:43 -0400 Subject: [PATCH 113/196] gopls/internal/test/integration: simplify GoToDefinition - Definition no longer calls OpenFile on the first location; it merely returns the list. Also, rename to Definitions. - GoToDefinition is renamed FirstDefinition; it asserts that the list is non-empty. - Tests that assumed OpenFile now call it explicitly. These changes should make it easier to write both positive and negative marker tests of the Definitions result set. 
Change-Id: I1e57bba8c78983f7b266ad876b1862b8af86a797 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675016 LUCI-TryBot-Result: Go LUCI Reviewed-by: Madeline Kalil --- .../test/integration/bench/definition_test.go | 4 +- .../integration/diagnostics/builtin_test.go | 2 +- .../diagnostics/diagnostics_test.go | 2 +- .../internal/test/integration/fake/editor.go | 44 ++------- .../test/integration/misc/definition_test.go | 94 +++++++++++-------- .../test/integration/misc/highlight_test.go | 11 ++- .../test/integration/misc/hover_test.go | 6 +- .../test/integration/misc/imports_test.go | 2 +- .../test/integration/misc/references_test.go | 7 +- .../test/integration/modfile/modfile_test.go | 4 +- .../integration/template/template_test.go | 3 +- .../test/integration/workspace/broken_test.go | 2 +- .../integration/workspace/metadata_test.go | 2 +- .../integration/workspace/standalone_test.go | 8 +- .../test/integration/workspace/vendor_test.go | 3 +- .../integration/workspace/workspace_test.go | 22 ++--- .../integration/workspace/zero_config_test.go | 3 +- gopls/internal/test/integration/wrappers.go | 26 +++-- gopls/internal/test/marker/marker_test.go | 6 +- 19 files changed, 127 insertions(+), 124 deletions(-) diff --git a/gopls/internal/test/integration/bench/definition_test.go b/gopls/internal/test/integration/bench/definition_test.go index e456d5a7c87..ea10cae16de 100644 --- a/gopls/internal/test/integration/bench/definition_test.go +++ b/gopls/internal/test/integration/bench/definition_test.go @@ -31,7 +31,7 @@ func BenchmarkDefinition(b *testing.B) { loc := env.RegexpSearch(test.file, test.regexp) env.Await(env.DoneWithOpen()) - env.GoToDefinition(loc) // pre-warm the query, and open the target file + env.FirstDefinition(loc) // pre-warm the query, and open the target file b.ResetTimer() if stopAndRecord := startProfileIfSupported(b, env, qualifiedName(test.repo, "definition")); stopAndRecord != nil { @@ -39,7 +39,7 @@ func BenchmarkDefinition(b *testing.B) { 
} for b.Loop() { - env.GoToDefinition(loc) // pre-warm the query + env.FirstDefinition(loc) // pre-warm the query } }) } diff --git a/gopls/internal/test/integration/diagnostics/builtin_test.go b/gopls/internal/test/integration/diagnostics/builtin_test.go index d6828a0df5c..eb4b098be74 100644 --- a/gopls/internal/test/integration/diagnostics/builtin_test.go +++ b/gopls/internal/test/integration/diagnostics/builtin_test.go @@ -26,7 +26,7 @@ const ( ` Run(t, src, func(t *testing.T, env *Env) { env.OpenFile("a.go") - loc := env.GoToDefinition(env.RegexpSearch("a.go", "iota")) + loc := env.FirstDefinition(env.RegexpSearch("a.go", "iota")) if !strings.HasSuffix(string(loc.URI), "builtin.go") { t.Fatalf("jumped to %q, want builtin.go", loc.URI) } diff --git a/gopls/internal/test/integration/diagnostics/diagnostics_test.go b/gopls/internal/test/integration/diagnostics/diagnostics_test.go index 5ef39a5f0c5..0d074333352 100644 --- a/gopls/internal/test/integration/diagnostics/diagnostics_test.go +++ b/gopls/internal/test/integration/diagnostics/diagnostics_test.go @@ -1790,7 +1790,7 @@ func helloHelper() {} env.AfterChange( NoDiagnostics(ForFile("nested/hello/hello.go")), ) - loc := env.GoToDefinition(env.RegexpSearch("nested/hello/hello.go", "helloHelper")) + loc := env.FirstDefinition(env.RegexpSearch("nested/hello/hello.go", "helloHelper")) want := "nested/hello/hello_helper.go" if got := env.Sandbox.Workdir.URIToPath(loc.URI); got != want { t.Errorf("Definition() returned %q, want %q", got, want) diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index a2dabf61c46..b5d8e4ccda0 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -932,54 +932,30 @@ func (e *Editor) setBufferContentLocked(ctx context.Context, path string, dirty return nil } -// GoToDefinition jumps to the definition of the symbol at the given position -// in an open buffer. 
It returns the location of the resulting jump. -func (e *Editor) Definition(ctx context.Context, loc protocol.Location) (protocol.Location, error) { +// Definitions returns the definitions of the symbol at the given +// location in an open buffer. +func (e *Editor) Definitions(ctx context.Context, loc protocol.Location) ([]protocol.Location, error) { if err := e.checkBufferLocation(loc); err != nil { - return protocol.Location{}, err + return nil, err } params := &protocol.DefinitionParams{} params.TextDocument.URI = loc.URI params.Position = loc.Range.Start - resp, err := e.Server.Definition(ctx, params) - if err != nil { - return protocol.Location{}, fmt.Errorf("definition: %w", err) - } - return e.extractFirstLocation(ctx, resp) + return e.Server.Definition(ctx, params) } -// TypeDefinition jumps to the type definition of the symbol at the given -// location in an open buffer. -func (e *Editor) TypeDefinition(ctx context.Context, loc protocol.Location) (protocol.Location, error) { +// TypeDefinitions returns the type definitions of the symbol at the +// given location in an open buffer. +func (e *Editor) TypeDefinitions(ctx context.Context, loc protocol.Location) ([]protocol.Location, error) { if err := e.checkBufferLocation(loc); err != nil { - return protocol.Location{}, err + return nil, err } params := &protocol.TypeDefinitionParams{} params.TextDocument.URI = loc.URI params.Position = loc.Range.Start - resp, err := e.Server.TypeDefinition(ctx, params) - if err != nil { - return protocol.Location{}, fmt.Errorf("type definition: %w", err) - } - return e.extractFirstLocation(ctx, resp) -} - -// extractFirstLocation returns the first location. -// It opens the file if needed. 
-func (e *Editor) extractFirstLocation(ctx context.Context, locs []protocol.Location) (protocol.Location, error) { - if len(locs) == 0 { - return protocol.Location{}, nil - } - - newPath := e.sandbox.Workdir.URIToPath(locs[0].URI) - if !e.HasBuffer(newPath) { - if err := e.OpenFile(ctx, newPath); err != nil { - return protocol.Location{}, fmt.Errorf("OpenFile: %w", err) - } - } - return locs[0], nil + return e.Server.TypeDefinition(ctx, params) } // Symbol performs a workspace symbol search using query diff --git a/gopls/internal/test/integration/misc/definition_test.go b/gopls/internal/test/integration/misc/definition_test.go index d36bb024672..8a9f27d20ac 100644 --- a/gopls/internal/test/integration/misc/definition_test.go +++ b/gopls/internal/test/integration/misc/definition_test.go @@ -41,13 +41,13 @@ const message = "Hello World." func TestGoToInternalDefinition(t *testing.T) { Run(t, internalDefinition, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", "message")) + loc := env.FirstDefinition(env.RegexpSearch("main.go", "message")) name := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "const.go"; name != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } if want := env.RegexpSearch("const.go", "message"); loc != want { - t.Errorf("GoToDefinition: got location %v, want %v", loc, want) + t.Errorf("Definition: got location %v, want %v", loc, want) } }) } @@ -90,13 +90,13 @@ func TestGoToLinknameDefinition(t *testing.T) { // Jump from directives 2nd arg. 
start := env.RegexpSearch("upper/upper.go", `lower.bar`) - loc := env.GoToDefinition(start) + loc := env.FirstDefinition(start) name := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "lower/lower.go"; name != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } if want := env.RegexpSearch("lower/lower.go", `bar`); loc != want { - t.Errorf("GoToDefinition: got position %v, want %v", loc, want) + t.Errorf("Definition: got position %v, want %v", loc, want) } }) } @@ -139,13 +139,13 @@ func TestGoToLinknameDefinitionInReverseDep(t *testing.T) { // Jump from directives 2nd arg. start := env.RegexpSearch("lower/lower.go", `upper.foo`) - loc := env.GoToDefinition(start) + loc := env.FirstDefinition(start) name := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "upper/upper.go"; name != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } if want := env.RegexpSearch("upper/upper.go", `foo`); loc != want { - t.Errorf("GoToDefinition: got position %v, want %v", loc, want) + t.Errorf("Definition: got position %v, want %v", loc, want) } }) } @@ -178,13 +178,13 @@ func TestGoToLinknameDefinitionDisconnected(t *testing.T) { // Jump from directives 2nd arg. 
start := env.RegexpSearch("a/a.go", `b.bar`) - loc := env.GoToDefinition(start) + loc := env.FirstDefinition(start) name := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "b/b.go"; name != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } if want := env.RegexpSearch("b/b.go", `bar`); loc != want { - t.Errorf("GoToDefinition: got position %v, want %v", loc, want) + t.Errorf("Definition: got position %v, want %v", loc, want) } }) } @@ -206,21 +206,22 @@ func main() { func TestGoToStdlibDefinition_Issue37045(t *testing.T) { Run(t, stdlibDefinition, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", `fmt.(Printf)`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `fmt.(Printf)`)) name := env.Sandbox.Workdir.URIToPath(loc.URI) if got, want := path.Base(name), "print.go"; got != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } + env.OpenFile(name) // Test that we can jump to definition from outside our workspace. // See golang.org/issues/37045. 
- newLoc := env.GoToDefinition(loc) + newLoc := env.FirstDefinition(loc) newName := env.Sandbox.Workdir.URIToPath(newLoc.URI) if newName != name { - t.Errorf("GoToDefinition is not idempotent: got %q, want %q", newName, name) + t.Errorf("Definition is not idempotent: got %q, want %q", newName, name) } if newLoc != loc { - t.Errorf("GoToDefinition is not idempotent: got %v, want %v", newLoc, loc) + t.Errorf("Definition is not idempotent: got %v, want %v", newLoc, loc) } }) } @@ -228,8 +229,9 @@ func TestGoToStdlibDefinition_Issue37045(t *testing.T) { func TestUnexportedStdlib_Issue40809(t *testing.T) { Run(t, stdlibDefinition, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", `fmt.(Printf)`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `fmt.(Printf)`)) name := env.Sandbox.Workdir.URIToPath(loc.URI) + env.OpenFile(name) loc = env.RegexpSearch(name, `:=\s*(newPrinter)\(\)`) @@ -239,7 +241,7 @@ func TestUnexportedStdlib_Issue40809(t *testing.T) { t.Errorf("expected 5+ references to newPrinter, found: %#v", refs) } - loc = env.GoToDefinition(loc) + loc = env.FirstDefinition(loc) content, _ := env.Hover(loc) if !strings.Contains(content.Value, "newPrinter") { t.Fatal("definition of newPrinter went to the incorrect place") @@ -306,7 +308,7 @@ func main() {} Settings{"importShortcut": tt.importShortcut}, ).Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", `"fmt"`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `"fmt"`)) if loc == (protocol.Location{}) { t.Fatalf("expected definition, got none") } @@ -354,7 +356,7 @@ func main() {} Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc, err := env.Editor.TypeDefinition(env.Ctx, env.RegexpSearch("main.go", tt.re)) + locs, err := env.Editor.TypeDefinitions(env.Ctx, env.RegexpSearch("main.go", tt.re)) if tt.wantError { if err == nil { t.Fatal("expected 
error, got nil") @@ -364,10 +366,13 @@ func main() {} if err != nil { t.Fatalf("expected nil error, got %s", err) } + if len(locs) == 0 { + t.Fatalf("TypeDefinitions: empty result") + } typeLoc := env.RegexpSearch("main.go", tt.wantTypeRe) - if loc != typeLoc { - t.Errorf("invalid pos: want %+v, got %+v", typeLoc, loc) + if locs[0] != typeLoc { + t.Errorf("invalid pos: want %+v, got %+v", typeLoc, locs[0]) } }) }) @@ -389,7 +394,16 @@ func F[T comparable]() {} Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") - _ = env.TypeDefinition(env.RegexpSearch("main.go", "comparable")) // must not panic + // TypeDefinition of comparable should + // returns an empty result, not panic. + loc := env.RegexpSearch("main.go", "comparable") + locs, err := env.Editor.TypeDefinitions(env.Ctx, loc) + if err != nil { + t.Fatal(err) + } + if len(locs) > 0 { + t.Fatalf("unexpected result: %v", locs) + } }) } @@ -429,7 +443,7 @@ package client ` Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("client/client_role_test.go") - env.GoToDefinition(env.RegexpSearch("client/client_role_test.go", "RoleSetup")) + env.FirstDefinition(env.RegexpSearch("client/client_role_test.go", "RoleSetup")) }) } @@ -487,11 +501,11 @@ const _ = b.K refLoc := env.RegexpSearch("a.go", "K") // find "b.K" reference // Initially, b.K is defined in the module cache. - gotLoc := env.GoToDefinition(refLoc) + gotLoc := env.FirstDefinition(refLoc) gotFile := env.Sandbox.Workdir.URIToPath(gotLoc.URI) wantCache := filepath.ToSlash(env.Sandbox.GOPATH()) + "/pkg/mod/other.com/b@v1.0.0/b.go" if gotFile != wantCache { - t.Errorf("GoToDefinition, before: got file %q, want %q", gotFile, wantCache) + t.Errorf("Definition, before: got file %q, want %q", gotFile, wantCache) } // Run 'go mod vendor' outside the editor. @@ -501,10 +515,10 @@ const _ = b.K env.Await(env.DoneWithChangeWatchedFiles()) // Now, b.K is defined in the vendor tree. 
- gotLoc = env.GoToDefinition(refLoc) + gotLoc = env.FirstDefinition(refLoc) wantVendor := "vendor/other.com/b/b.go" if gotFile != wantVendor { - t.Errorf("GoToDefinition, after go mod vendor: got file %q, want %q", gotFile, wantVendor) + t.Errorf("Definition, after go mod vendor: got file %q, want %q", gotFile, wantVendor) } // Delete the vendor tree. @@ -520,10 +534,10 @@ const _ = b.K env.Await(env.DoneWithChangeWatchedFiles()) // b.K is once again defined in the module cache. - gotLoc = env.GoToDefinition(gotLoc) + gotLoc = env.FirstDefinition(gotLoc) gotFile = env.Sandbox.Workdir.URIToPath(gotLoc.URI) if gotFile != wantCache { - t.Errorf("GoToDefinition, after rm -rf vendor: got file %q, want %q", gotFile, wantCache) + t.Errorf("Definition, after rm -rf vendor: got file %q, want %q", gotFile, wantCache) } }) } @@ -554,16 +568,16 @@ FOO SKIP ` -func TestGoToEmbedDefinition(t *testing.T) { +func TestEmbedDefinition(t *testing.T) { Run(t, embedDefinition, func(t *testing.T, env *Env) { env.OpenFile("main.go") start := env.RegexpSearch("main.go", `\*.txt`) - loc := env.GoToDefinition(start) + loc := env.FirstDefinition(start) name := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "foo.txt"; name != want { - t.Errorf("GoToDefinition: got file %q, want %q", name, want) + t.Errorf("Definition: got file %q, want %q", name, want) } }) } @@ -588,10 +602,10 @@ func _(err error) { env.OpenFile("a.go") start := env.RegexpSearch("a.go", `Error`) - loc := env.GoToDefinition(start) + loc := env.FirstDefinition(start) if !strings.HasSuffix(string(loc.URI), "builtin.go") { - t.Errorf("GoToDefinition(err.Error) = %#v, want builtin.go", loc) + t.Errorf("Definition(err.Error) = %#v, want builtin.go", loc) } }) } @@ -628,13 +642,13 @@ var _ = foo(123) // call // Definition at the call"foo(123)" takes us to the Go declaration. 
callLoc := env.RegexpSearch("a.go", regexp.QuoteMeta("foo(123)")) - declLoc := env.GoToDefinition(callLoc) + declLoc := env.FirstDefinition(callLoc) if got, want := locString(declLoc), "a.go:5:5-5:8"; got != want { t.Errorf("Definition(call): got %s, want %s", got, want) } // Definition a second time takes us to the assembly implementation. - implLoc := env.GoToDefinition(declLoc) + implLoc := env.FirstDefinition(declLoc) if got, want := locString(implLoc), "foo_darwin_arm64.s:2:6-2:9"; got != want { t.Errorf("Definition(go decl): got %s, want %s", got, want) } @@ -670,7 +684,7 @@ func Foo() { env.OpenFile("a.go") fooLoc := env.RegexpSearch("a.go", "()Foo") - loc0 := env.GoToDefinition(fooLoc) + loc0 := env.FirstDefinition(fooLoc) // Insert a space that will be removed by formatting. env.EditBuffer("a.go", protocol.TextEdit{ @@ -679,7 +693,7 @@ func Foo() { }) env.SaveBuffer("a.go") // reformats the file before save env.AfterChange() - loc1 := env.GoToDefinition(env.RegexpSearch("a.go", "Foo")) + loc1 := env.FirstDefinition(env.RegexpSearch("a.go", "Foo")) if diff := cmp.Diff(loc0, loc1); diff != "" { t.Errorf("mismatching locations (-want +got):\n%s", diff) } diff --git a/gopls/internal/test/integration/misc/highlight_test.go b/gopls/internal/test/integration/misc/highlight_test.go index 36bddf25057..95105df4d7c 100644 --- a/gopls/internal/test/integration/misc/highlight_test.go +++ b/gopls/internal/test/integration/misc/highlight_test.go @@ -30,7 +30,7 @@ func main() { Run(t, mod, func(t *testing.T, env *Env) { const file = "main.go" env.OpenFile(file) - loc := env.GoToDefinition(env.RegexpSearch(file, `var (A) string`)) + loc := env.FirstDefinition(env.RegexpSearch(file, `var (A) string`)) checkHighlights(env, loc, 3) }) @@ -53,8 +53,9 @@ func main() { Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") - defLoc := env.GoToDefinition(env.RegexpSearch("main.go", `fmt\.(Printf)`)) + defLoc := env.FirstDefinition(env.RegexpSearch("main.go", 
`fmt\.(Printf)`)) file := env.Sandbox.Workdir.URIToPath(defLoc.URI) + env.OpenFile(file) loc := env.RegexpSearch(file, `func Printf\((format) string`) checkHighlights(env, loc, 2) @@ -111,13 +112,15 @@ func main() {}` ).Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") - defLoc := env.GoToDefinition(env.RegexpSearch("main.go", `"example.com/global"`)) + defLoc := env.FirstDefinition(env.RegexpSearch("main.go", `"example.com/global"`)) file := env.Sandbox.Workdir.URIToPath(defLoc.URI) + env.OpenFile(file) loc := env.RegexpSearch(file, `const (A)`) checkHighlights(env, loc, 4) - defLoc = env.GoToDefinition(env.RegexpSearch("main.go", `"example.com/local"`)) + defLoc = env.FirstDefinition(env.RegexpSearch("main.go", `"example.com/local"`)) file = env.Sandbox.Workdir.URIToPath(defLoc.URI) + env.OpenFile(file) loc = env.RegexpSearch(file, `const (b)`) checkHighlights(env, loc, 5) }) diff --git a/gopls/internal/test/integration/misc/hover_test.go b/gopls/internal/test/integration/misc/hover_test.go index 7be50efe6d4..b6b7b679357 100644 --- a/gopls/internal/test/integration/misc/hover_test.go +++ b/gopls/internal/test/integration/misc/hover_test.go @@ -66,8 +66,9 @@ func main() { t.Errorf("Workspace hover: missing expected field 'unexported'. 
Got:\n%q", got.Value) } - cacheLoc := env.GoToDefinition(mixedLoc) + cacheLoc := env.FirstDefinition(mixedLoc) cacheFile := env.Sandbox.Workdir.URIToPath(cacheLoc.URI) + env.OpenFile(cacheFile) argLoc := env.RegexpSearch(cacheFile, "printMixed.*(Mixed)") got, _ = env.Hover(argLoc) if !strings.Contains(got.Value, "unexported") { @@ -644,7 +645,8 @@ func (e) Error() string for _, builtin := range tests { useLocation := env.RegexpSearch("p.go", builtin) calleeHover, _ := env.Hover(useLocation) - declLocation := env.GoToDefinition(useLocation) + declLocation := env.FirstDefinition(useLocation) + env.OpenFile(env.Sandbox.Workdir.URIToPath(declLocation.URI)) declHover, _ := env.Hover(declLocation) if diff := cmp.Diff(calleeHover, declHover); diff != "" { t.Errorf("Hover mismatch (-callee hover +decl hover):\n%s", diff) diff --git a/gopls/internal/test/integration/misc/imports_test.go b/gopls/internal/test/integration/misc/imports_test.go index bcbfacc967a..bdb5ea25318 100644 --- a/gopls/internal/test/integration/misc/imports_test.go +++ b/gopls/internal/test/integration/misc/imports_test.go @@ -241,7 +241,7 @@ var _, _ = x.X, y.Y env.AfterChange(Diagnostics(env.AtRegexp("main.go", `y.Y`))) env.SaveBuffer("main.go") env.AfterChange(NoDiagnostics(ForFile("main.go"))) - loc := env.GoToDefinition(env.RegexpSearch("main.go", `y.(Y)`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `y.(Y)`)) path := env.Sandbox.Workdir.URIToPath(loc.URI) if !strings.HasPrefix(path, filepath.ToSlash(modcache)) { t.Errorf("found module dependency outside of GOMODCACHE: got %v, wanted subdir of %v", path, filepath.ToSlash(modcache)) diff --git a/gopls/internal/test/integration/misc/references_test.go b/gopls/internal/test/integration/misc/references_test.go index 58fdb3c5cd8..1cc6c593174 100644 --- a/gopls/internal/test/integration/misc/references_test.go +++ b/gopls/internal/test/integration/misc/references_test.go @@ -37,7 +37,8 @@ func main() { Run(t, files, func(t *testing.T, env 
*Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", `fmt.(Print)`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `fmt.(Print)`)) + env.OpenFile(env.Sandbox.Workdir.URIToPath(loc.URI)) refs, err := env.Editor.References(env.Ctx, loc) if err != nil { t.Fatal(err) @@ -82,7 +83,7 @@ func _() { ` Run(t, files, func(t *testing.T, env *Env) { env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", `Error`)) + loc := env.FirstDefinition(env.RegexpSearch("main.go", `Error`)) refs, err := env.Editor.References(env.Ctx, loc) if err != nil { t.Fatalf("references on (*s).Error failed: %v", err) @@ -131,7 +132,7 @@ var _ = unsafe.Slice(nil, 0) loc := env.RegexpSearch("a.go", `\b`+name+`\b`) // definition -> {builtin,unsafe}.go - def := env.GoToDefinition(loc) + def := env.FirstDefinition(loc) if (!strings.HasSuffix(string(def.URI), "builtin.go") && !strings.HasSuffix(string(def.URI), "unsafe.go")) || def.Range.Start.Line == 0 { diff --git a/gopls/internal/test/integration/modfile/modfile_test.go b/gopls/internal/test/integration/modfile/modfile_test.go index 5a194246a42..dfd50c3effb 100644 --- a/gopls/internal/test/integration/modfile/modfile_test.go +++ b/gopls/internal/test/integration/modfile/modfile_test.go @@ -869,13 +869,13 @@ func hello() {} env.RegexpReplace("go.mod", "module", "modul") // Confirm that we still have metadata with only on-disk edits. env.OpenFile("main.go") - loc := env.GoToDefinition(env.RegexpSearch("main.go", "hello")) + loc := env.FirstDefinition(env.RegexpSearch("main.go", "hello")) if filepath.Base(string(loc.URI)) != "hello.go" { t.Fatalf("expected definition in hello.go, got %s", loc.URI) } // Confirm that we no longer have metadata when the file is saved. 
env.SaveBufferWithoutActions("go.mod") - _, err := env.Editor.Definition(env.Ctx, env.RegexpSearch("main.go", "hello")) + _, err := env.Editor.Definitions(env.Ctx, env.RegexpSearch("main.go", "hello")) if err == nil { t.Fatalf("expected error, got none") } diff --git a/gopls/internal/test/integration/template/template_test.go b/gopls/internal/test/integration/template/template_test.go index 3087e1d60fd..796fe5e0a57 100644 --- a/gopls/internal/test/integration/template/template_test.go +++ b/gopls/internal/test/integration/template/template_test.go @@ -191,7 +191,8 @@ go 1.12 ).Run(t, files, func(t *testing.T, env *Env) { env.OpenFile("a.tmpl") x := env.RegexpSearch("a.tmpl", `A`) - loc := env.GoToDefinition(x) + loc := env.FirstDefinition(x) + env.OpenFile(env.Sandbox.Workdir.URIToPath(loc.URI)) refs := env.References(loc) if len(refs) != 2 { t.Fatalf("got %v reference(s), want 2", len(refs)) diff --git a/gopls/internal/test/integration/workspace/broken_test.go b/gopls/internal/test/integration/workspace/broken_test.go index 33b0b834eb6..ba70959c9a7 100644 --- a/gopls/internal/test/integration/workspace/broken_test.go +++ b/gopls/internal/test/integration/workspace/broken_test.go @@ -107,7 +107,7 @@ const CompleteMe = 222 env.AfterChange(NoOutstandingWork(IgnoreTelemetryPromptWork)) // Check that definitions in package1 go to the copy vendored in package2. 
- location := string(env.GoToDefinition(env.RegexpSearch("package1/main.go", "CompleteMe")).URI) + location := string(env.FirstDefinition(env.RegexpSearch("package1/main.go", "CompleteMe")).URI) const wantLocation = "package2/vendor/example.com/foo/foo.go" if !strings.HasSuffix(location, wantLocation) { t.Errorf("got definition of CompleteMe at %q, want %q", location, wantLocation) diff --git a/gopls/internal/test/integration/workspace/metadata_test.go b/gopls/internal/test/integration/workspace/metadata_test.go index 71ca4329777..b8d3c7ee25d 100644 --- a/gopls/internal/test/integration/workspace/metadata_test.go +++ b/gopls/internal/test/integration/workspace/metadata_test.go @@ -158,7 +158,7 @@ func Hello() int { // Now, to satisfy a definition request, gopls will try to reload moda. But // without access to the proxy (because this is no longer a // reinitialization), this loading will fail. - loc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + loc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) got := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "b.com@v1.2.3/b/b.go"; !strings.HasSuffix(got, want) { t.Errorf("expected %s, got %v", want, got) diff --git a/gopls/internal/test/integration/workspace/standalone_test.go b/gopls/internal/test/integration/workspace/standalone_test.go index 3b690465744..6cd8e7fa81a 100644 --- a/gopls/internal/test/integration/workspace/standalone_test.go +++ b/gopls/internal/test/integration/workspace/standalone_test.go @@ -118,17 +118,17 @@ func main() { } // We should resolve workspace definitions in the standalone file. 
- fileLoc := env.GoToDefinition(env.RegexpSearch("lib/ignore.go", "lib.(K)")) + fileLoc := env.FirstDefinition(env.RegexpSearch("lib/ignore.go", "lib.(K)")) file := env.Sandbox.Workdir.URIToPath(fileLoc.URI) if got, want := file, "lib/lib.go"; got != want { - t.Errorf("GoToDefinition(lib.K) = %v, want %v", got, want) + t.Errorf("Definition(lib.K) = %v, want %v", got, want) } // ...as well as intra-file definitions - loc := env.GoToDefinition(env.RegexpSearch("lib/ignore.go", "\\+ (K)")) + loc := env.FirstDefinition(env.RegexpSearch("lib/ignore.go", "\\+ (K)")) wantLoc := env.RegexpSearch("lib/ignore.go", "const (K)") if loc != wantLoc { - t.Errorf("GoToDefinition(K) = %v, want %v", loc, wantLoc) + t.Errorf("Definition(K) = %v, want %v", loc, wantLoc) } // Renaming "lib.K" to "lib.D" should cause a diagnostic in the standalone diff --git a/gopls/internal/test/integration/workspace/vendor_test.go b/gopls/internal/test/integration/workspace/vendor_test.go index 10826430164..0b07d4acddc 100644 --- a/gopls/internal/test/integration/workspace/vendor_test.go +++ b/gopls/internal/test/integration/workspace/vendor_test.go @@ -56,7 +56,8 @@ var _ b.B env.AfterChange( NoDiagnostics(), // as b is not a workspace package ) - env.GoToDefinition(env.RegexpSearch("a.go", `b\.(B)`)) + loc := env.FirstDefinition(env.RegexpSearch("a.go", `b\.(B)`)) + env.OpenFile(env.Sandbox.Workdir.URIToPath(loc.URI)) env.AfterChange( Diagnostics(env.AtRegexp("vendor/other.com/b/b.go", "V"), WithMessage("not used")), ) diff --git a/gopls/internal/test/integration/workspace/workspace_test.go b/gopls/internal/test/integration/workspace/workspace_test.go index fc96a47dbe0..ed4961312ef 100644 --- a/gopls/internal/test/integration/workspace/workspace_test.go +++ b/gopls/internal/test/integration/workspace/workspace_test.go @@ -265,7 +265,7 @@ func TestWorkspaceVendoring(t *testing.T) { env.OpenFile("moda/a/a.go") env.RunGoCommand("work", "vendor") env.AfterChange() - loc := 
env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "b.(Hello)")) + loc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "b.(Hello)")) const want = "vendor/b.com/b/b.go" if got := env.Sandbox.Workdir.URIToPath(loc.URI); got != want { t.Errorf("Definition: got location %q, want %q", got, want) @@ -375,12 +375,11 @@ func Hello() int { env.OpenFile("moda/a/a.go") env.Await(env.DoneWithOpen()) - originalLoc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + originalLoc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) original := env.Sandbox.Workdir.URIToPath(originalLoc.URI) if want := "modb/b/b.go"; !strings.HasSuffix(original, want) { t.Errorf("expected %s, got %v", want, original) } - env.CloseBuffer(original) env.AfterChange() env.RemoveWorkspaceFile("modb/b/b.go") @@ -388,7 +387,7 @@ func Hello() int { env.WriteWorkspaceFile("go.work", "go 1.18\nuse moda/a") env.AfterChange() - gotLoc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + gotLoc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) got := env.Sandbox.Workdir.URIToPath(gotLoc.URI) if want := "b.com@v1.2.3/b/b.go"; !strings.HasSuffix(got, want) { t.Errorf("expected %s, got %v", want, got) @@ -429,12 +428,11 @@ func main() { ProxyFiles(workspaceModuleProxy), ).Run(t, multiModule, func(t *testing.T, env *Env) { env.OpenFile("moda/a/a.go") - loc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + loc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) original := env.Sandbox.Workdir.URIToPath(loc.URI) if want := "b.com@v1.2.3/b/b.go"; !strings.HasSuffix(original, want) { t.Errorf("expected %s, got %v", want, original) } - env.CloseBuffer(original) env.WriteWorkspaceFiles(map[string]string{ "go.work": `go 1.18 @@ -452,7 +450,7 @@ func Hello() int { `, }) env.AfterChange(Diagnostics(env.AtRegexp("modb/b/b.go", "x"))) - gotLoc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + gotLoc := 
env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) got := env.Sandbox.Workdir.URIToPath(gotLoc.URI) if want := "modb/b/b.go"; !strings.HasSuffix(got, want) { t.Errorf("expected %s, got %v", want, original) @@ -587,7 +585,7 @@ use ( // To verify which modules are loaded, we'll jump to the definition of // b.Hello. checkHelloLocation := func(want string) error { - loc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + loc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) file := env.Sandbox.Workdir.URIToPath(loc.URI) if !strings.HasSuffix(file, want) { return fmt.Errorf("expected %s, got %v", want, file) @@ -812,7 +810,7 @@ use ( ).Run(t, workspace, func(t *testing.T, env *Env) { env.OpenFile("moda/a/a.go") env.Await(env.DoneWithOpen()) - loc := env.GoToDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) + loc := env.FirstDefinition(env.RegexpSearch("moda/a/a.go", "Hello")) file := env.Sandbox.Workdir.URIToPath(loc.URI) want := "modb/b/b.go" if !strings.HasSuffix(file, want) { @@ -857,7 +855,7 @@ const B = 0 WorkspaceFolders("a"), ).Run(t, workspace, func(t *testing.T, env *Env) { env.OpenFile("a/a.go") - loc := env.GoToDefinition(env.RegexpSearch("a/a.go", "b.(B)")) + loc := env.FirstDefinition(env.RegexpSearch("a/a.go", "b.(B)")) got := env.Sandbox.Workdir.URIToPath(loc.URI) want := "b/b.go" if got != want { @@ -886,7 +884,7 @@ var _ = fmt.Printf ).Run(t, files, func(t *testing.T, env *Env) { env.CreateBuffer("outside/foo.go", "") env.EditBuffer("outside/foo.go", fake.NewEdit(0, 0, 0, 0, code)) - env.GoToDefinition(env.RegexpSearch("outside/foo.go", `Printf`)) + env.FirstDefinition(env.RegexpSearch("outside/foo.go", `Printf`)) }) } @@ -1083,7 +1081,7 @@ func (Server) Foo() {} ) // This will cause a test failure if other_test.go is not in any package. 
- _ = env.GoToDefinition(env.RegexpSearch("other_test.go", "Server")) + _ = env.FirstDefinition(env.RegexpSearch("other_test.go", "Server")) }) } diff --git a/gopls/internal/test/integration/workspace/zero_config_test.go b/gopls/internal/test/integration/workspace/zero_config_test.go index 95906274b93..ac16e65f68d 100644 --- a/gopls/internal/test/integration/workspace/zero_config_test.go +++ b/gopls/internal/test/integration/workspace/zero_config_test.go @@ -312,10 +312,11 @@ func _() { ).Run(t, src, func(t *testing.T, env *Env) { env.OpenFile("a.go") env.AfterChange(NoDiagnostics()) - loc := env.GoToDefinition(env.RegexpSearch("a.go", `b\.(B)`)) + loc := env.FirstDefinition(env.RegexpSearch("a.go", `b\.(B)`)) if !strings.Contains(string(loc.URI), "/vendor/") { t.Fatalf("Definition(b.B) = %v, want vendored location", loc.URI) } + env.OpenFile(env.Sandbox.Workdir.URIToPath(loc.URI)) env.AfterChange( Diagnostics(env.AtRegexp("vendor/other.com/b/b.go", "V"), WithMessage("not used")), ) diff --git a/gopls/internal/test/integration/wrappers.go b/gopls/internal/test/integration/wrappers.go index 17e0cf329c4..d24be7bbe98 100644 --- a/gopls/internal/test/integration/wrappers.go +++ b/gopls/internal/test/integration/wrappers.go @@ -189,26 +189,32 @@ func (e *Env) SaveBufferWithoutActions(name string) { } } -// GoToDefinition goes to definition in the editor, calling t.Fatal on any -// error. It returns the path and position of the resulting jump. -// -// TODO(rfindley): rename this to just 'Definition'. -func (e *Env) GoToDefinition(loc protocol.Location) protocol.Location { +// FirstDefinition returns the first definition of the symbol at the +// selected location, calling t.Fatal on error. 
+func (e *Env) FirstDefinition(loc protocol.Location) protocol.Location { e.TB.Helper() - loc, err := e.Editor.Definition(e.Ctx, loc) + locs, err := e.Editor.Definitions(e.Ctx, loc) if err != nil { e.TB.Fatal(err) } - return loc + if len(locs) == 0 { + e.TB.Fatalf("no definitions") + } + return locs[0] } -func (e *Env) TypeDefinition(loc protocol.Location) protocol.Location { +// FirstTypeDefinition returns the first type definition of the symbol +// at the selected location, calling t.Fatal on error. +func (e *Env) FirstTypeDefinition(loc protocol.Location) protocol.Location { e.TB.Helper() - loc, err := e.Editor.TypeDefinition(e.Ctx, loc) + locs, err := e.Editor.TypeDefinitions(e.Ctx, loc) if err != nil { e.TB.Fatal(err) } - return loc + if len(locs) == 0 { + e.TB.Fatalf("no type definitions") + } + return locs[0] } // FormatBuffer formats the editor buffer, calling t.Fatal on any error. diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 261add7b3b7..e12fa0f46a3 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -1686,7 +1686,7 @@ func acceptCompletionMarker(mark marker, src protocol.Location, label string, go // // TODO(rfindley): support a variadic destination set. 
func defMarker(mark marker, src, dst protocol.Location) { - got := mark.run.env.GoToDefinition(src) + got := mark.run.env.FirstDefinition(src) if got != dst { mark.errorf("definition location does not match:\n\tgot: %s\n\twant %s", mark.run.fmtLoc(got), mark.run.fmtLoc(dst)) @@ -1694,7 +1694,7 @@ func defMarker(mark marker, src, dst protocol.Location) { } func typedefMarker(mark marker, src, dst protocol.Location) { - got := mark.run.env.TypeDefinition(src) + got := mark.run.env.FirstTypeDefinition(src) if got != dst { mark.errorf("type definition location does not match:\n\tgot: %s\n\twant %s", mark.run.fmtLoc(got), mark.run.fmtLoc(dst)) @@ -1850,7 +1850,7 @@ func locMarker(mark marker, loc protocol.Location) protocol.Location { return lo // defLocMarker implements the @defloc marker, which binds a location to the // (first) result of a jump-to-definition request. func defLocMarker(mark marker, loc protocol.Location) protocol.Location { - return mark.run.env.GoToDefinition(loc) + return mark.run.env.FirstDefinition(loc) } // diagMarker implements the @diag marker. It eliminates diagnostics from From d5ec4a9294fbe8c602ac69a1ff329e6cf41f79e5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 17 May 2025 12:59:45 -0400 Subject: [PATCH 114/196] internal/mcp: add notifications Add notification logic for intialization and list changes. 
Change-Id: I1d392bb49d5995f2046c400da955a033ab2715ce Reviewed-on: https://go-review.googlesource.com/c/tools/+/673775 Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 59 +++++++++++++++--- internal/mcp/generate.go | 24 ++++++-- internal/mcp/mcp_test.go | 49 ++++++++++++++- internal/mcp/protocol.go | 62 ++++++++++++++----- internal/mcp/server.go | 130 +++++++++++++++++++++++++-------------- internal/mcp/shared.go | 27 ++++++++ 6 files changed, 273 insertions(+), 78 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index c9fed8b134f..1bf76df0e8b 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -49,6 +49,10 @@ type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) + // Handlers for notifications from the server. + ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) + PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) + ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) } // bind implements the binder[*ClientSession] interface, so that Clients can @@ -86,10 +90,13 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e if err != nil { return nil, err } + caps := &ClientCapabilities{} + caps.Roots.ListChanged = true if c.opts.CreateMessageHandler != nil { caps.Sampling = &SamplingCapabilities{} } + params := &InitializeParams{ ClientInfo: &implementation{Name: c.name, Version: c.version}, Capabilities: caps, @@ -133,11 +140,13 @@ func (c *ClientSession) Wait() error { // AddRoots adds the given roots to the client, // replacing any with the same URIs, // and notifies any connected servers. 
-// TODO: notification func (c *Client) AddRoots(roots ...*Root) { - c.mu.Lock() - defer c.mu.Unlock() - c.roots.add(roots...) + // Only notify if something could change. + if len(roots) == 0 { + return + } + c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { c.roots.add(roots...); return true }) } // RemoveRoots removes the roots with the given URIs, @@ -145,9 +154,22 @@ func (c *Client) AddRoots(roots ...*Root) { // It is not an error to remove a nonexistent root. // TODO: notification func (c *Client) RemoveRoots(uris ...string) { + c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { return c.roots.remove(uris...) }) +} + +// changeAndNotify is called when a feature is added or removed. +// It calls change, which should do the work and report whether a change actually occurred. +// If there was a change, it notifies a snapshot of the sessions. +func (c *Client) changeAndNotify(notification string, params any, change func() bool) { + var sessions []*ClientSession + // Lock for the change, but not for the notification. c.mu.Lock() - defer c.mu.Unlock() - c.roots.remove(uris...) + if change() { + sessions = slices.Clone(c.sessions) + } + c.mu.Unlock() + notifySessions(sessions, notification, params) } func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsParams) (*ListRootsResult, error) { @@ -180,10 +202,12 @@ func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) { // clientMethodInfos maps from the RPC method name to serverMethodInfos. 
var clientMethodInfos = map[string]methodInfo[ClientSession]{ - methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), - methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), - methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), - // TODO: notifications + methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), + methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), + methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), + notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), + notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), + notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), } var _ session[ClientSession] = (*ClientSession)(nil) @@ -202,6 +226,9 @@ func (cs *ClientSession) methodHandler() MethodHandler[ClientSession] { return cs.client.methodHandler_ } +// getConn implements [session.getConn]. 
+func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } + func (c *ClientSession) ping(ct context.Context, params *PingParams) (struct{}, error) { return struct{}{}, nil } @@ -263,6 +290,18 @@ func (c *ClientSession) ReadResource(ctx context.Context, params *ReadResourcePa return standardCall[ReadResourceResult](ctx, c.conn, methodReadResource, params) } +func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (any, error) { + return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) +} + +func (c *Client) callPromptChangedHandler(ctx context.Context, s *ClientSession, params *PromptListChangedParams) (any, error) { + return callNotificationHandler(ctx, c.opts.PromptListChangedHandler, s, params) +} + +func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSession, params *ResourceListChangedParams) (any, error) { + return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) +} + func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { var result TRes if err := call(ctx, conn, method, params, &result); err != nil { diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 3a67bd6b8db..c557c7c2d1c 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -120,8 +120,12 @@ var declarations = config{ "Prompt": {}, "PromptMessage": {}, "PromptArgument": {}, - "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string - "RequestId": {Name: "-", Substitute: "any"}, // null|number|string + "PromptListChangedNotification": { + Name: "-", + Fields: config{"Params": {Name: "PromptListChangedParams"}}, + }, + "ProgressToken": {Name: "-", Substitute: "any"}, // null|number|string + "RequestId": {Name: "-", Substitute: "any"}, // null|number|string "ReadResourceRequest": { Name: "-", Fields: config{"Params": {Name: 
"ReadResourceParams"}}, @@ -130,8 +134,16 @@ var declarations = config{ Fields: config{"Contents": {Substitute: "*ResourceContents"}}, }, "Resource": {}, - "Role": {}, - "Root": {}, + "ResourceListChangedNotification": { + Name: "-", + Fields: config{"Params": {Name: "ResourceListChangedParams"}}, + }, + "Role": {}, + "Root": {}, + "RootsListChangedNotification": { + Name: "-", + Fields: config{"Params": {Name: "RootsListChangedParams"}}, + }, "SamplingCapabilities": {Substitute: "struct{}"}, "SamplingMessage": {}, @@ -147,6 +159,10 @@ var declarations = config{ Fields: config{"InputSchema": {Substitute: "*jsonschema.Schema"}}, }, "ToolAnnotations": {}, + "ToolListChangedNotification": { + Name: "-", + Fields: config{"Params": {Name: "ToolListChangedParams"}}, + }, } func main() { diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 98ced31bdc6..f783496e19e 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -39,7 +39,26 @@ func TestEndToEnd(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) + // Channels to check if notification callbacks happened. + notificationChans := map[string]chan int{} + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources"} { + notificationChans[name] = make(chan int, 1) + } + waitForNotification := func(t *testing.T, name string) { + t.Helper() + select { + case <-notificationChans[name]: + case <-time.After(time.Second): + t.Fatalf("%s handler never called", name) + } + } + + sopts := &ServerOptions{ + InitializedHandler: func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, + RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { notificationChans["roots"] <- 0 }, + } + + s := NewServer("testServer", "v1.0.0", sopts) // The 'greet' tool says hi. 
s.AddTools(NewTool("greet", "say hi", sayHi)) @@ -89,6 +108,9 @@ func TestEndToEnd(t *testing.T) { CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel"}, nil }, + ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, + PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, + ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, } c := NewClient("testClient", "v1.0.0", opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -103,6 +125,7 @@ func TestEndToEnd(t *testing.T) { t.Fatal(err) } + waitForNotification(t, "initialized") if err := cs.Ping(ctx, nil); err != nil { t.Fatalf("ping failed: %v", err) } @@ -141,6 +164,11 @@ func TestEndToEnd(t *testing.T) { if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), failure.Error()) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) } + + s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) + waitForNotification(t, "prompts") + s.RemovePrompts("T") + waitForNotification(t, "prompts") }) t.Run("tools", func(t *testing.T) { @@ -198,6 +226,11 @@ func TestEndToEnd(t *testing.T) { if diff := cmp.Diff(wantFail, gotFail); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } + + s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}}) + waitForNotification(t, "tools") + s.RemoveTools("T") + waitForNotification(t, "tools") }) t.Run("resources", func(t *testing.T) { @@ -254,6 +287,11 @@ func TestEndToEnd(t *testing.T) { } } } + + s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) + waitForNotification(t, "resources") + s.RemoveResources("http://U") + 
waitForNotification(t, "resources") }) t.Run("roots", func(t *testing.T) { rootRes, err := ss.ListRoots(ctx, &ListRootsParams{}) @@ -265,6 +303,11 @@ func TestEndToEnd(t *testing.T) { if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { t.Errorf("roots/list mismatch (-want +got):\n%s", diff) } + + c.AddRoots(&Root{URI: "U"}) + waitForNotification(t, "roots") + c.RemoveRoots("U") + waitForNotification(t, "roots") }) t.Run("sampling", func(t *testing.T) { // TODO: test that a client that doesn't have the handler returns CodeUnsupportedMethod. @@ -462,6 +505,10 @@ func TestMiddleware(t *testing.T) { 2 >initialize 2 notifications/initialized +2 >notifications/initialized +2 tools/list 2 >tools/list 2 Date: Sat, 17 May 2025 21:25:14 -0400 Subject: [PATCH 115/196] internal/mcp: construct README with weave Use the weave program from golang.org/x/example to weave together the README text with buildable programs. This ensures the code in the README compiles. Change-Id: Ib5e10c8f808cc6d1ed4addbee784607e910ba30b Reviewed-on: https://go-review.googlesource.com/c/tools/+/673796 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/README.md | 35 ++++++++---- internal/mcp/internal/readme/Makefile | 19 +++++++ internal/mcp/internal/readme/README.src.md | 57 +++++++++++++++++++ internal/mcp/internal/readme/build.sh | 22 +++++++ internal/mcp/internal/readme/client/client.go | 40 +++++++++++++ internal/mcp/internal/readme/server/server.go | 32 +++++++++++ internal/mcp/mcp.go | 1 + 7 files changed, 196 insertions(+), 10 deletions(-) create mode 100644 internal/mcp/internal/readme/Makefile create mode 100644 internal/mcp/internal/readme/README.src.md create mode 100755 internal/mcp/internal/readme/build.sh create mode 100644 internal/mcp/internal/readme/client/client.go create mode 100644 internal/mcp/internal/readme/server/server.go diff --git a/internal/mcp/README.md b/internal/mcp/README.md index 761df9b2d27..c1af3729182 100644 --- a/internal/mcp/README.md +++ 
b/internal/mcp/README.md @@ -1,18 +1,28 @@ + # MCP SDK prototype [![PkgGoDev](https://pkg.go.dev/badge/golang.org/x/tools)](https://pkg.go.dev/golang.org/x/tools/internal/mcp) +# Contents + +1. [Installation](#installation) +1. [Quickstart](#quickstart) +1. [Design](#design) +1. [Testing](#testing) +1. [Code of Conduct](#code-of-conduct) +1. [License](#license) + The mcp package provides a software development kit (SDK) for writing clients and servers of the [model context protocol](https://modelcontextprotocol.io/introduction). It is unstable, and will change in breaking ways in the future. As of writing, it is a prototype to explore the design space of client/server transport and binding. -## Installation +# Installation The mcp package is currently internal and cannot be imported using `go get`. -## Quickstart +# Quickstart Here's an example that creates a client that talks to an MCP server running as a sidecar process: @@ -35,15 +45,20 @@ func main() { // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) - if err != nil { + if err != nil { log.Fatal(err) } - defer session.Close() + defer session.Close() // Call a tool on the server. - if content, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { + if res, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { log.Printf("CallTool failed: %v", err) } else { - log.Printf("CallTool returns: %v", content) + if res.IsError { + log.Print("tool failed") + } + for _, c := range res.Content { + log.Print(c.Text) + } } } ``` @@ -78,13 +93,13 @@ func main() { } ``` -## Design +# Design See [design.md](./design/design.md) for the SDK design. That document is canonical: given any divergence between the design doc and this prototype, the doc reflects the latest design. 
-## Testing +# Testing To test your client or server using stdio transport, you can use an in-memory transport. See [example](server_example_test.go). @@ -92,12 +107,12 @@ transport. See [example](server_example_test.go). To test your client or server using sse transport, you can use the [httptest](https://pkg.go.dev/net/http/httptest) package. See [example](sse_example_test.go). -## Code of Conduct +# Code of Conduct This project follows the [Go Community Code of Conduct](https://go.dev/conduct). If you encounter a conduct-related issue, please mail conduct@golang.org. -## License +# License Unless otherwise noted, the Go source files are distributed under the BSD-style license found in the [LICENSE](../../LICENSE) file. diff --git a/internal/mcp/internal/readme/Makefile b/internal/mcp/internal/readme/Makefile new file mode 100644 index 00000000000..5e48418407d --- /dev/null +++ b/internal/mcp/internal/readme/Makefile @@ -0,0 +1,19 @@ +# This makefile builds ../README.md from the files in this directory. + +OUTFILE=../../README.md + +$(OUTFILE): build README.src.md + go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $(OUTFILE) + +# Compile all the code used in the README. +build: $(wildcard */*.go) + go build -o /tmp/mcp-readme/ ./... + +# Preview the README on GitHub. +# $HOME/markdown must be a github repo. +# Visit https://github.com/$HOME/markdown to see the result. +preview: $(OUTFILE) + cp $(OUTFILE) $$HOME/markdown/ + (cd $$HOME/markdown/ && git commit -m . 
README.md && git push) + +.PHONY: build preview diff --git a/internal/mcp/internal/readme/README.src.md b/internal/mcp/internal/readme/README.src.md new file mode 100644 index 00000000000..310e6d0aa44 --- /dev/null +++ b/internal/mcp/internal/readme/README.src.md @@ -0,0 +1,57 @@ +# MCP SDK prototype + +[![PkgGoDev](https://pkg.go.dev/badge/golang.org/x/tools)](https://pkg.go.dev/golang.org/x/tools/internal/mcp) + +# Contents + +%toc + +The mcp package provides a software development kit (SDK) for writing clients +and servers of the [model context +protocol](https://modelcontextprotocol.io/introduction). It is unstable, and +will change in breaking ways in the future. As of writing, it is a prototype to +explore the design space of client/server transport and binding. + +# Installation + +The mcp package is currently internal and cannot be imported using `go get`. + +# Quickstart + +Here's an example that creates a client that talks to an MCP server running +as a sidecar process: + +%include client/client.go - + +Here is an example of the corresponding server, connected over stdin/stdout: + +%include server/server.go - + +# Design + +See [design.md](./design/design.md) for the SDK design. That document is +canonical: given any divergence between the design doc and this prototype, the +doc reflects the latest design. + +# Testing + +To test your client or server using stdio transport, you can use an in-memory +transport. See [example](server_example_test.go). + +To test your client or server using sse transport, you can use the [httptest](https://pkg.go.dev/net/http/httptest) +package. See [example](sse_example_test.go). + +# Code of Conduct + +This project follows the [Go Community Code of Conduct](https://go.dev/conduct). +If you encounter a conduct-related issue, please mail conduct@golang.org. + +# License + +Unless otherwise noted, the Go source files are distributed under the BSD-style +license found in the [LICENSE](../../LICENSE) file. 
+ +Upon a potential move to the +[modelcontextprotocol](https://github.com/modelcontextprotocol) organization, +the license will be updated to the MIT License, and the license header will +reflect the Go MCP SDK Authors. diff --git a/internal/mcp/internal/readme/build.sh b/internal/mcp/internal/readme/build.sh new file mode 100755 index 00000000000..354c46c0267 --- /dev/null +++ b/internal/mcp/internal/readme/build.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +# Build README.md from the files in this directory. +# Must be invoked from the internal/cmp directory. + +cd internal/readme + +outfile=../../README.md + +# Compile all the code used in the README. +go build -o /tmp/mcp-readme/ ./... +# Combine the code with the text in README.src.md. +# TODO: when at Go 1.24, use a tool directive for weave. +go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $outfile + +if [[ $1 = '-preview' ]]; then + # Preview the README on GitHub. + # $HOME/markdown must be a github repo. + # Visit https://github.com/$HOME/markdown to see the result. + cp $outfile $HOME/markdown/ + (cd $HOME/markdown/ && git commit -m . README.md && git push) +fi diff --git a/internal/mcp/internal/readme/client/client.go b/internal/mcp/internal/readme/client/client.go new file mode 100644 index 00000000000..97600e7d2ab --- /dev/null +++ b/internal/mcp/internal/readme/client/client.go @@ -0,0 +1,40 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// !+ +package main + +import ( + "context" + "log" + "os/exec" + + "golang.org/x/tools/internal/mcp" +) + +func main() { + ctx := context.Background() + // Create a new client, with no features. 
+ client := mcp.NewClient("mcp-client", "v1.0.0", nil) + // Connect to a server over stdin/stdout + transport := mcp.NewCommandTransport(exec.Command("myserver")) + session, err := client.Connect(ctx, transport) + if err != nil { + log.Fatal(err) + } + defer session.Close() + // Call a tool on the server. + if res, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { + log.Printf("CallTool failed: %v", err) + } else { + if res.IsError { + log.Print("tool failed") + } + for _, c := range res.Content { + log.Print(c.Text) + } + } +} + +//!- diff --git a/internal/mcp/internal/readme/server/server.go b/internal/mcp/internal/readme/server/server.go new file mode 100644 index 00000000000..185a7297d4d --- /dev/null +++ b/internal/mcp/internal/readme/server/server.go @@ -0,0 +1,32 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// !+ +package main + +import ( + "context" + + "golang.org/x/tools/internal/mcp" +) + +type HiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { + return []*mcp.Content{ + mcp.NewTextContent("Hi " + params.Name), + }, nil +} + +func main() { + // Create a server with a single tool. + server := mcp.NewServer("greeter", "v1.0.0", nil) + server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) + // Run the server over stdin/stdout, until the client diconnects + _ = server.Run(context.Background(), mcp.NewStdIOTransport()) +} + +// !- diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index d1cd6c7a900..d4ae8ee2442 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -3,6 +3,7 @@ // license that can be found in the LICENSE file. 
//go:generate go run generate.go +//go:generate ./internal/readme/build.sh // The mcp package provides an SDK for writing model context protocol clients // and servers. From 35a9265740bf19f2cda31fc93fa22c8803bdb1e3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 27 Apr 2025 09:32:18 -0400 Subject: [PATCH 116/196] jsonschema: add more schema fields Add additional fields to Schema which are mentioned in the spec but do not affect validation. Change-Id: I89488261a00c207d01cd7fc59d782f43692fe528 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674976 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI Auto-Submit: Jonathan Amsterdam --- internal/mcp/jsonschema/resolve.go | 8 ++ internal/mcp/jsonschema/schema.go | 39 +++++++-- internal/mcp/jsonschema/schema_test.go | 16 ++-- .../testdata/draft2020-12/default.json | 82 +++++++++++++++++++ internal/mcp/jsonschema/validate_test.go | 6 +- 5 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 internal/mcp/jsonschema/testdata/draft2020-12/default.json diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index f82eec1b78f..d28fba42cd9 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -153,6 +153,14 @@ func (s *Schema) checkLocal(report func(error)) { // TODO: validate the schema's properties, // ideally by jsonschema-validating it against the meta-schema. + // Some properties are present so that Schemas can round-trip, but we do not + // validate them. + // Currently, it's just the $vocabulary property. + // As a special case, we can validate the 2020-12 meta-schema. + if s.Vocabulary != nil && s.Schema != draft202012 { + addf("cannot validate a schema with $vocabulary") + } + // Check and compile regexps. 
if s.Pattern != "" { re, err := regexp.Compile(s.Pattern) diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 1ec4b0d4bf2..d6d5f765d6c 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -58,6 +58,11 @@ type Schema struct { // metadata Title string `json:"title,omitempty"` Description string `json:"description,omitempty"` + Default *any `json:"default,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + WriteOnly bool `json:"writeOnly,omitempty"` + Examples []any `json:"examples,omitempty"` // validation // Use Type for a single type, or Types for multiple types; never both. @@ -110,6 +115,15 @@ type Schema struct { Else *Schema `json:"else,omitempty"` DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` + // other + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 + ContentEncoding string `json:"contentEncoding,omitempty"` + ContentMediaType string `json:"contentMediaType,omitempty"` + ContentSchema *Schema `json:"contentSchema,omitempty"` + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 + Format string `json:"format,omitempty"` + // computed fields // This schema's base schema. @@ -237,6 +251,7 @@ func (s *Schema) UnmarshalJSON(data []byte) error { ms := struct { Type json.RawMessage `json:"type,omitempty"` Const json.RawMessage `json:"const,omitempty"` + Default json.RawMessage `json:"default,omitempty"` MinLength *integer `json:"minLength,omitempty"` MaxLength *integer `json:"maxLength,omitempty"` MinItems *integer `json:"minItems,omitempty"` @@ -269,14 +284,24 @@ func (s *Schema) UnmarshalJSON(data []byte) error { return err } - // Setting Const to a pointer to null will marshal properly, but won't unmarshal: - // the *any is set to nil, not a pointer to nil. 
- if len(ms.Const) > 0 { - if bytes.Equal(ms.Const, []byte("null")) { - s.Const = new(any) - } else if err := json.Unmarshal(ms.Const, &s.Const); err != nil { - return err + unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { + if len(raw) == 0 { + return nil } + if bytes.Equal(raw, []byte("null")) { + *p = new(any) + return nil + } + return json.Unmarshal(raw, p) + } + + // Setting Const or Default to a pointer to null will marshal properly, but won't + // unmarshal: the *any is set to nil, not a pointer to nil. + if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { + return err + } + if err := unmarshalAnyPtr(&s.Default, ms.Default); err != nil { + return err } set := func(dst **int, src *integer) { diff --git a/internal/mcp/jsonschema/schema_test.go b/internal/mcp/jsonschema/schema_test.go index 4d042d560b6..8394bb587fa 100644 --- a/internal/mcp/jsonschema/schema_test.go +++ b/internal/mcp/jsonschema/schema_test.go @@ -24,6 +24,7 @@ func TestGoRoundTrip(t *testing.T) { {Const: Ptr(any(nil))}, {Const: Ptr(any([]int{}))}, {Const: Ptr(any(map[string]any{}))}, + {Default: Ptr(any(nil))}, } { data, err := json.Marshal(s) if err != nil { @@ -31,9 +32,7 @@ func TestGoRoundTrip(t *testing.T) { } t.Logf("marshal: %s", data) var got *Schema - if err := json.Unmarshal(data, &got); err != nil { - t.Fatal(err) - } + mustUnmarshal(t, data, &got) if !Equal(got, s) { t.Errorf("got %+v, want %+v", got, s) if got.Const != nil && s.Const != nil { @@ -68,9 +67,7 @@ func TestJSONRoundTrip(t *testing.T) { {`{"unk":0}`, `{}`}, // unknown fields are dropped, unfortunately } { var s Schema - if err := json.Unmarshal([]byte(tt.in), &s); err != nil { - t.Fatal(err) - } + mustUnmarshal(t, []byte(tt.in), &s) data, err := json.Marshal(s) if err != nil { t.Fatal(err) @@ -126,3 +123,10 @@ func TestEvery(t *testing.T) { t.Errorf("got %d, want %d", got, want) } } + +func mustUnmarshal(t *testing.T, data []byte, ptr any) { + t.Helper() + if err := json.Unmarshal(data, ptr); err 
!= nil { + t.Fatal(err) + } +} diff --git a/internal/mcp/jsonschema/testdata/draft2020-12/default.json b/internal/mcp/jsonschema/testdata/draft2020-12/default.json new file mode 100644 index 00000000000..ceb3ae27172 --- /dev/null +++ b/internal/mcp/jsonschema/testdata/draft2020-12/default.json @@ -0,0 +1,82 @@ +[ + { + "description": "invalid type for default", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "foo": { + "type": "integer", + "default": [] + } + } + }, + "tests": [ + { + "description": "valid when property is specified", + "data": {"foo": 13}, + "valid": true + }, + { + "description": "still valid when the invalid default is used", + "data": {}, + "valid": true + } + ] + }, + { + "description": "invalid string value for default", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "bar": { + "type": "string", + "minLength": 4, + "default": "bad" + } + } + }, + "tests": [ + { + "description": "valid when property is specified", + "data": {"bar": "good"}, + "valid": true + }, + { + "description": "still valid when the invalid default is used", + "data": {}, + "valid": true + } + ] + }, + { + "description": "the default keyword does not do anything if the property is missing", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "alpha": { + "type": "number", + "maximum": 3, + "default": 5 + } + } + }, + "tests": [ + { + "description": "an explicit property value is checked against maximum (passing)", + "data": { "alpha": 1 }, + "valid": true + }, + { + "description": "an explicit property value is checked against maximum (failing)", + "data": { "alpha": 5 }, + "valid": false + }, + { + "description": "missing properties are not filled in with the default", + "data": {}, + "valid": true + } + ] + } +] diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 
3d096dfcef1..d6be6f8cfc4 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -42,14 +42,12 @@ func TestValidate(t *testing.T) { for _, file := range files { base := filepath.Base(file) t.Run(base, func(t *testing.T) { - f, err := os.Open(file) + data, err := os.ReadFile(file) if err != nil { t.Fatal(err) } - defer f.Close() - dec := json.NewDecoder(f) var groups []testGroup - if err := dec.Decode(&groups); err != nil { + if err := json.Unmarshal(data, &groups); err != nil { t.Fatal(err) } for _, g := range groups { From ac05d44a965d2dec5a94074cd19c697bd5bb7ca0 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 27 Apr 2025 10:15:45 -0400 Subject: [PATCH 117/196] jsonschema: package doc Document the package. Change-Id: I770c14d6339b75ea6827280c23171c3584dc1071 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674977 Reviewed-by: Robert Findley Reviewed-by: Alan Donovan Auto-Submit: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/doc.go | 74 +++++++++++++++++++++++++++++++ internal/mcp/jsonschema/schema.go | 2 - 2 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 internal/mcp/jsonschema/doc.go diff --git a/internal/mcp/jsonschema/doc.go b/internal/mcp/jsonschema/doc.go new file mode 100644 index 00000000000..390dc5fa904 --- /dev/null +++ b/internal/mcp/jsonschema/doc.go @@ -0,0 +1,74 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package jsonschema is an implementation of the [JSON Schema specification], +a JSON-based format for describing the structure of JSON data. +The package can be used to read schemas for code generation, and to validate data using the +draft 2020-12 specification. Validation with other drafts or custom meta-schemas +is not supported. 
+ +Construct a [Schema] as you would any Go struct (for example, by writing a struct +literal), or unmarshal a JSON schema into a [Schema] in the usual way (with [encoding/json], +for instance). It can then be used for code generation or other purposes without +further processing. + +# Validation + +Before using a Schema to validate a JSON value, you must first resolve it by calling +[Schema.Resolve]. +The call [Resolved.Validate] on the result to validate a JSON value. +The value must be a Go value that looks like the result of unmarshaling a JSON +value into an [any] or a struct. For example, the JSON value + + {"name": "Al", "scores": [90, 80, 100]} + +could be represented as the Go value + + map[string]any{ + "name": "Al", + "scores": []any{90, 80, 100}, + } + +or as a value of this type: + + type Player struct { + Name string `json:"name"` + Scores []int `json:"scores"` + } + +# Inference + +The [For] and [ForType] functions return a [Schema] describing the given Go type. +The type cannot contain any function or channel types, and any map types must have a string key. +For example, calling For on the above Player type results in this schema: + + { + "properties": { + "name": { + "type": "string" + }, + "scores": { + "type": "array", + "items": {"type": "integer"} + } + "required": ["name", "scores"], + "additionalProperties": {"not": {}} + } + } + +# Deviations from the specification + +Regular expressions are processed with Go's regexp package, which differs from ECMA 262, +most significantly in not supporting back-references. +See [this table of differences] for more. + +The value of the "format" keyword is recorded in the Schema, but is ignored during validation. +It does not even produce [annotations]. 
+ +[JSON Schema specification]: https://json-schema.org +[this table of differences] https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 +[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations +*/ +package jsonschema diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index d6d5f765d6c..3fbf861af17 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package jsonschema is an implementation of the JSON Schema -// specification: https://json-schema.org. package jsonschema import ( From 66d4add26c4aacaca70299e95484216e67089601 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 22 May 2025 10:13:32 -0400 Subject: [PATCH 118/196] internal/gocommand: re-disable flaky TestRmdirAfterGoList_Runner The use of WaitDelay means completion of the Run operation does not mean the go command has terminated. Needs more thought. Updates golang/go#73736 Change-Id: Ia7852fc1cae13e841c0bc28f3567809e1b92915f Reviewed-on: https://go-review.googlesource.com/c/tools/+/675515 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Jonathan Amsterdam --- internal/gocommand/invoke_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/gocommand/invoke_test.go b/internal/gocommand/invoke_test.go index 7e29135633c..0d4dbb1eb13 100644 --- a/internal/gocommand/invoke_test.go +++ b/internal/gocommand/invoke_test.go @@ -41,6 +41,8 @@ func TestGoVersion(t *testing.T) { // If this test ever fails, the combination of the gocommand package // and the go command itself has a bug; this has been observed (#73503). 
func TestRmdirAfterGoList_Runner(t *testing.T) { + t.Skip("flaky; see https://github.com/golang/go/issues/73736#issuecomment-2885407104") + testRmdirAfterGoList(t, func(ctx context.Context, dir string) { var runner gocommand.Runner stdout, stderr, friendlyErr, err := runner.RunRaw(ctx, gocommand.Invocation{ From 4d4fb92b593b2ece90eca450caa06196edf0ca68 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 12 May 2025 08:35:02 -0400 Subject: [PATCH 119/196] internal/mcp: implement logging Provide a ServerSession method that directly implements the spec method. Also provide a custom slog.Handler that writes a notification on each log call. It uses a slog.JSONHandler to format the message as JSON, then adds the additional fields of LoggingMessageParams. I decided that the direct method should handle levels, since that is in the spec, but that rate-limiting is an extra feature, so is only provided in the slog handler. Change-Id: I1190ca124e992e61fb633ac52ee3492c8e6d99c9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674498 Auto-Submit: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley --- internal/mcp/client.go | 13 +++ internal/mcp/design/design.md | 29 +++++- internal/mcp/generate.go | 29 ++++-- internal/mcp/logging.go | 189 ++++++++++++++++++++++++++++++++++ internal/mcp/mcp_test.go | 84 +++++++++++++++ internal/mcp/protocol.go | 72 ++++++++----- internal/mcp/server.go | 42 ++++++-- 7 files changed, 416 insertions(+), 42 deletions(-) create mode 100644 internal/mcp/logging.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 1bf76df0e8b..79ccd89e859 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -53,6 +53,7 @@ type ClientOptions struct { ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) ResourceListChangedHandler func(context.Context, 
*ClientSession, *ResourceListChangedParams) + LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) } // bind implements the binder[*ClientSession] interface, so that Clients can @@ -208,6 +209,7 @@ var clientMethodInfos = map[string]methodInfo[ClientSession]{ notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), + notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), } var _ session[ClientSession] = (*ClientSession)(nil) @@ -273,6 +275,10 @@ func (c *ClientSession) CallTool(ctx context.Context, name string, args map[stri return standardCall[CallToolResult](ctx, c.conn, methodCallTool, params) } +func (c *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { + return call(ctx, c.conn, methodSetLevel, params, nil) +} + // NOTE: the following struct should consist of all fields of callToolParams except name and arguments. // CallToolOptions contains options to [ClientSession.CallTool]. 
@@ -302,6 +308,13 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) } +func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (any, error) { + if h := c.opts.LoggingMessageHandler; h != nil { + h(ctx, cs, params) + } + return nil, nil +} + func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { var result TRes if err := call(ctx, conn, method, params, &result); err != nil { diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 91963d8f84d..28de9aa6b4b 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -823,6 +823,27 @@ Clients call the spec method `Complete` to request completions. Servers automati ### Logging +MCP specifies a notification for servers to log to clients. Server sessions implement this with the `LoggingMessage` method. It honors the minimum log level established by the client session's `SetLevel` call. + +As a convenience, we also provide a `slog.Handler` that allows server authors to write logs with the `log/slog` package:: +```go +// A LoggingHandler is a [slog.Handler] for MCP. +type LoggingHandler struct {...} + +// LoggingHandlerOptions are options for a LoggingHandler. +type LoggingHandlerOptions struct { + // The value for the "logger" field of logging notifications. + LoggerName string + // Limits the rate at which log messages are sent. + // If zero, there is no rate limiting. + MinInterval time.Duration +} + +// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a +// [slog.JSONHandler]. 
+func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler +``` + Server-to-client logging is configured with `ServerOptions`: ```go @@ -832,24 +853,24 @@ type ServerOptions { LoggerName string // Log notifications to a single ClientSession will not be // sent more frequently than this duration. - LogInterval time.Duration + LoggingInterval time.Duration } ``` -Server sessions have a field `Logger` holding a `slog.Logger` that writes to the client session. A call to a log method like `Info` is translated to a `LoggingMessageNotification` as follows: +A call to a log method like `Info` is translated to a `LoggingMessageNotification` as follows: - The attributes and the message populate the "data" property with the output of a `slog.JSONHandler`: The result is always a JSON object, with the key "msg" for the message. - If the `LoggerName` server option is set, it populates the "logger" property. -- The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the corresponding levels in the MCP spec. The other spec levels map to integers between the slog levels. For example, "notice" is level 2 because it is between "warning" (slog value 4) and "info" (slog value 0). The `mcp` package defines consts for these levels. To log at the "notice" level, a handler would call `session.Logger.Log(ctx, mcp.LevelNotice, "message")`. +- The standard slog levels `Info`, `Debug`, `Warn` and `Error` map to the corresponding levels in the MCP spec. The other spec levels map to integers between the slog levels. For example, "notice" is level 2 because it is between "warning" (slog value 4) and "info" (slog value 0). The `mcp` package defines consts for these levels. To log at the "notice" level, a handler would call `Log(ctx, mcp.LevelNotice, "message")`. A client that wishes to receive log messages must provide a handler: ```go type ClientOptions struct { ... 
- LogMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) + LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) } ``` diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index c557c7c2d1c..a7cfe6d745e 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -19,6 +19,7 @@ import ( "go/format" "io" "log" + "maps" "net/http" "os" "reflect" @@ -110,7 +111,18 @@ var declarations = config{ Name: "-", Fields: config{"Params": {Name: "ListToolsParams"}}, }, - "ListToolsResult": {}, + "ListToolsResult": {}, + "loggingCapabilities": {Substitute: "struct{}"}, + "LoggingLevel": {}, + "LoggingMessageNotification": { + Name: "-", + Fields: config{ + "Params": { + Name: "LoggingMessageParams", + Fields: config{"Data": {Substitute: "any"}}, + }, + }, + }, "ModelHint": {}, "ModelPreferences": {}, "PingRequest": { @@ -153,8 +165,13 @@ var declarations = config{ "Prompts": {Name: "promptCapabilities"}, "Resources": {Name: "resourceCapabilities"}, "Tools": {Name: "toolCapabilities"}, + "Logging": {Name: "loggingCapabilities"}, }, }, + "SetLevelRequest": { + Name: "-", + Fields: config{"Params": {Name: "SetLevelParams"}}, + }, "Tool": { Fields: config{"InputSchema": {Substitute: "*jsonschema.Schema"}}, }, @@ -220,7 +237,7 @@ import ( } // Write out method names. 
fmt.Fprintln(buf, `const (`) - for name, s := range schema.Definitions { + for _, name := range slices.Sorted(maps.Keys(schema.Definitions)) { prefix := "method" method, found := strings.CutSuffix(name, "Request") if !found { @@ -228,7 +245,7 @@ import ( method, found = strings.CutSuffix(name, "Notification") } if found { - if ms, ok := s.Properties["method"]; ok { + if ms, ok := schema.Definitions[name].Properties["method"]; ok { if c := ms.Const; c != nil { fmt.Fprintf(buf, "%s%s = %q\n", prefix, method, *c) } @@ -395,9 +412,9 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma fieldTypeSchema = rs } needPointer := isStruct(fieldTypeSchema) - // Special case: there are no sampling capabilities defined, but - // we want it to be a struct for future expansion. - if !needPointer && name == "sampling" { + // Special case: there are no sampling or logging capabilities defined, + // but we want them to be structs for future expansion. + if !needPointer && (name == "sampling" || name == "logging") { needPointer = true } if config != nil && config.Fields[export] != nil { diff --git a/internal/mcp/logging.go b/internal/mcp/logging.go new file mode 100644 index 00000000000..14c5420bd04 --- /dev/null +++ b/internal/mcp/logging.go @@ -0,0 +1,189 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "log/slog" + "sync" + "time" +) + +// Logging levels. 
+const ( + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelNotice = (slog.LevelInfo + slog.LevelWarn) / 2 + LevelWarning = slog.LevelWarn + LevelError = slog.LevelError + LevelCritical = slog.LevelError + 4 + LevelAlert = slog.LevelError + 8 + LevelEmergency = slog.LevelError + 12 +) + +var slogToMCP = map[slog.Level]LoggingLevel{ + LevelDebug: "debug", + LevelInfo: "info", + LevelNotice: "notice", + LevelWarning: "warning", + LevelError: "error", + LevelCritical: "critical", + LevelAlert: "alert", + LevelEmergency: "emergency", +} + +var mcpToSlog = make(map[LoggingLevel]slog.Level) + +func init() { + for sl, ml := range slogToMCP { + mcpToSlog[ml] = sl + } +} + +func slogLevelToMCP(sl slog.Level) LoggingLevel { + if ml, ok := slogToMCP[sl]; ok { + return ml + } + return "debug" // for lack of a better idea +} + +func mcpLevelToSlog(ll LoggingLevel) slog.Level { + if sl, ok := mcpToSlog[ll]; ok { + return sl + } + // TODO: is there a better default? + return LevelDebug +} + +// compareLevels behaves like [cmp.Compare] for [LoggingLevel]s. +func compareLevels(l1, l2 LoggingLevel) int { + return cmp.Compare(mcpLevelToSlog(l1), mcpLevelToSlog(l2)) +} + +// LoggingHandlerOptions are options for a LoggingHandler. +type LoggingHandlerOptions struct { + // The value for the "logger" field of logging notifications. + LoggerName string + // Limits the rate at which log messages are sent. + // If zero, there is no rate limiting. + MinInterval time.Duration +} + +// A LoggingHandler is a [slog.Handler] for MCP. +type LoggingHandler struct { + opts LoggingHandlerOptions + ss *ServerSession + // Ensures that the buffer reset is atomic with the write (see Handle). + // A pointer so that clones share the mutex. See + // https://github.com/golang/example/blob/master/slog-handler-guide/README.md#getting-the-mutex-right. 
+ mu *sync.Mutex + lastMessageSent time.Time // for rate-limiting + buf *bytes.Buffer + handler slog.Handler +} + +// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a +// [slog.JSONHandler]. +func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { + var buf bytes.Buffer + jsonHandler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Remove level: it appears in LoggingMessageParams. + if a.Key == slog.LevelKey { + return slog.Attr{} + } + return a + }, + }) + lh := &LoggingHandler{ + ss: ss, + mu: new(sync.Mutex), + buf: &buf, + handler: jsonHandler, + } + if opts != nil { + lh.opts = *opts + } + return lh +} + +// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { + // This is also checked in ServerSession.LoggingMessage, so checking it here + // is just an optimization that skips building the JSON. + h.ss.mu.Lock() + mcpLevel := h.ss.logLevel + h.ss.mu.Unlock() + return level >= mcpLevelToSlog(mcpLevel) +} + +// WithAttrs implements [slog.Handler.WithAttrs]. +func (h *LoggingHandler) WithAttrs(as []slog.Attr) slog.Handler { + h2 := *h + h2.handler = h.handler.WithAttrs(as) + return &h2 +} + +// WithGroup implements [slog.Handler.WithGroup]. +func (h *LoggingHandler) WithGroup(name string) slog.Handler { + h2 := *h + h2.handler = h.handler.WithGroup(name) + return &h2 +} + +// Handle implements [slog.Handler.Handle] by writing the Record to a JSONHandler, +// then calling [ServerSession.LoggingMesssage] with the result. +func (h *LoggingHandler) Handle(ctx context.Context, r slog.Record) error { + err := h.handle(ctx, r) + // TODO(jba): find a way to surface the error. + // The return value will probably be ignored. 
+ return err +} + +func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { + // Observe the rate limit. + // TODO(jba): use golang.org/x/time/rate. (We can't here because it would require adding + // golang.org/x/time to the go.mod file.) + h.mu.Lock() + skip := time.Since(h.lastMessageSent) < h.opts.MinInterval + h.mu.Unlock() + if skip { + return nil + } + + var err error + // Make the buffer reset atomic with the record write. + // We are careful here in the unlikely event that the handler panics. + // We don't want to hold the lock for the entire function, because Notify is + // an I/O operation. + // This can result in out-of-order delivery. + func() { + h.mu.Lock() + defer h.mu.Unlock() + h.buf.Reset() + err = h.handler.Handle(ctx, r) + }() + if err != nil { + return err + } + + h.mu.Lock() + h.lastMessageSent = time.Now() + h.mu.Unlock() + + params := &LoggingMessageParams{ + Logger: h.opts.LoggerName, + Level: slogLevelToMCP(r.Level), + Data: json.RawMessage(h.buf.Bytes()), + } + // We pass the argument context to Notify, even though slog.Handler.Handle's + // documentation says not to. + // In this case logging is a service to clients, not a means for debugging the + // server, so we want to cancel the log message. 
+ return h.ss.LoggingMessage(ctx, params) +} diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index f783496e19e..a9ca09b67c5 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "path/filepath" "runtime" "slices" @@ -104,6 +105,7 @@ func TestEndToEnd(t *testing.T) { clientWG.Done() }() + loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel"}, nil @@ -111,6 +113,9 @@ func TestEndToEnd(t *testing.T) { ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, + LoggingMessageHandler: func(_ context.Context, _ *ClientSession, lm *LoggingMessageParams) { + loggingMessages <- lm + }, } c := NewClient("testClient", "v1.0.0", opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -319,6 +324,85 @@ func TestEndToEnd(t *testing.T) { t.Errorf("got %q, want %q", g, w) } }) + t.Run("logging", func(t *testing.T) { + want := []*LoggingMessageParams{ + { + Logger: "test", + Level: "warning", + Data: map[string]any{ + "msg": "first", + "name": "Pat", + "logtest": true, + }, + }, + { + Logger: "test", + Level: "alert", + Data: map[string]any{ + "msg": "second", + "count": 2.0, + "logtest": true, + }, + }, + } + + check := func(t *testing.T) { + t.Helper() + var got []*LoggingMessageParams + // Read messages from this test until we've seen all we expect. 
+ for len(got) < len(want) { + select { + case p := <-loggingMessages: + // Ignore logging from other tests. + if m, ok := p.Data.(map[string]any); ok && m["logtest"] != nil { + delete(m, "time") // remove time because it changes + got = append(got, p) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for log messages") + } + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } + } + + t.Run("direct", func(t *testing.T) { // Use the LoggingMessage method directly. + + mustLog := func(level LoggingLevel, data any) { + t.Helper() + if err := ss.LoggingMessage(ctx, &LoggingMessageParams{ + Logger: "test", + Level: level, + Data: data, + }); err != nil { + t.Fatal(err) + } + } + + // Nothing should be logged until the client sets a level. + mustLog("info", "before") + if err := cs.SetLevel(ctx, &SetLevelParams{Level: "warning"}); err != nil { + t.Fatal(err) + } + mustLog("warning", want[0].Data) + mustLog("debug", "nope") // below the level + mustLog("info", "negative") // below the level + mustLog("alert", want[1].Data) + check(t) + }) + + t.Run("handler", func(t *testing.T) { // Use the slog handler. + // We can't check the "before SetLevel" behavior because it's already been set. + // Not a big deal: that check is in LoggingMessage anyway. + logger := slog.New(NewLoggingHandler(ss, &LoggingHandlerOptions{LoggerName: "test"})) + logger.Warn("first", "name", "Pat", "logtest", true) + logger.Debug("nope") // below the level + logger.Info("negative") // below the level + logger.Log(ctx, LevelAlert, "second", "count", 2, "logtest", true) + check(t) + }) + }) // Disconnect. cs.Close() diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 9f222222a49..25aeff1f567 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -241,6 +241,22 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +// The severity of a log message. 
+// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +type LoggingMessageParams struct { + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` +} + // Hints to use for model selection. // // Keys not declared here are currently left unspecified by the spec and are up @@ -326,8 +342,7 @@ type PromptArgument struct { type PromptListChangedParams struct { // This parameter name is reserved by MCP to allow clients and servers to attach // additional metadata to their notifications. - Meta struct { - } `json:"_meta,omitempty"` + Meta map[string]json.RawMessage `json:"_meta,omitempty"` } // Describes a message returned as part of a prompt. @@ -381,8 +396,7 @@ type Resource struct { type ResourceListChangedParams struct { // This parameter name is reserved by MCP to allow clients and servers to attach // additional metadata to their notifications. - Meta struct { - } `json:"_meta,omitempty"` + Meta map[string]json.RawMessage `json:"_meta,omitempty"` } // The sender or recipient of messages and data in a conversation. @@ -403,8 +417,7 @@ type Root struct { type RootsListChangedParams struct { // This parameter name is reserved by MCP to allow clients and servers to attach // additional metadata to their notifications. - Meta struct { - } `json:"_meta,omitempty"` + Meta map[string]json.RawMessage `json:"_meta,omitempty"` } // Present if the client supports sampling from an LLM. @@ -417,6 +430,13 @@ type SamplingMessage struct { Role Role `json:"role"` } +type SetLevelParams struct { + // The level of logging that the client wants to receive from the server. 
The + // server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/message. + Level LoggingLevel `json:"level"` +} + // Definition for a tool the client can call. type Tool struct { // Optional additional tool information. @@ -472,8 +492,7 @@ type ToolAnnotations struct { type ToolListChangedParams struct { // This parameter name is reserved by MCP to allow clients and servers to attach // additional metadata to their notifications. - Meta struct { - } `json:"_meta,omitempty"` + Meta map[string]json.RawMessage `json:"_meta,omitempty"` } // Describes the name and version of an MCP implementation. @@ -482,6 +501,10 @@ type implementation struct { Version string `json:"version"` } +// Present if the server supports sending log messages to the client. +type loggingCapabilities struct { +} + // Present if the server offers any prompt templates. type promptCapabilities struct { // Whether this server supports notifications for changes to the prompt list. @@ -507,8 +530,7 @@ type serverCapabilities struct { Experimental map[string]struct { } `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. - Logging struct { - } `json:"logging,omitempty"` + Logging *loggingCapabilities `json:"logging,omitempty"` // Present if the server offers any prompt templates. Prompts *promptCapabilities `json:"prompts,omitempty"` // Present if the server offers any resources to read. 
@@ -524,28 +546,28 @@ type toolCapabilities struct { } const ( - methodCreateMessage = "sampling/createMessage" - notificationToolListChanged = "notifications/tools/list_changed" - notificationResourceListChanged = "notifications/resources/list_changed" - methodListPrompts = "prompts/list" - notificationPromptListChanged = "notifications/prompts/list_changed" - notificationResourceUpdated = "notifications/resources/updated" + methodCallTool = "tools/call" notificationCancelled = "notifications/cancelled" - methodSetLevel = "logging/setLevel" - methodInitialize = "initialize" - methodListRoots = "roots/list" - notificationProgress = "notifications/progress" + methodComplete = "completion/complete" + methodCreateMessage = "sampling/createMessage" methodGetPrompt = "prompts/get" + methodInitialize = "initialize" + notificationInitialized = "notifications/initialized" + methodListPrompts = "prompts/list" methodListResourceTemplates = "resources/templates/list" - methodPing = "ping" - methodComplete = "completion/complete" methodListResources = "resources/list" - notificationInitialized = "notifications/initialized" - methodCallTool = "tools/call" + methodListRoots = "roots/list" + methodListTools = "tools/list" notificationLoggingMessage = "notifications/message" + methodPing = "ping" + notificationProgress = "notifications/progress" + notificationPromptListChanged = "notifications/prompts/list_changed" methodReadResource = "resources/read" - methodListTools = "tools/list" + notificationResourceListChanged = "notifications/resources/list_changed" + notificationResourceUpdated = "notifications/resources/updated" notificationRootsListChanged = "notifications/roots/list_changed" + methodSetLevel = "logging/setLevel" methodSubscribe = "resources/subscribe" + notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" ) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 93690615f25..3e491dab982 100644 --- 
a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -299,11 +299,11 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - cc := &ServerSession{conn: conn, server: s} + ss := &ServerSession{conn: conn, server: s} s.mu.Lock() - s.sessions = append(s.sessions, cc) + s.sessions = append(s.sessions, ss) s.mu.Unlock() - return cc + return ss } // disconnect implements the binder[*ServerSession] interface, so that @@ -341,17 +341,17 @@ func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSess // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - + server *Server + conn *jsonrpc2.Connection mu sync.Mutex + logLevel LoggingLevel initializeParams *InitializeParams initialized bool } // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, _ *PingParams) error { - return call(ctx, ss.conn, "ping", (*PingParams)(nil), nil) + return call(ctx, ss.conn, methodPing, (*PingParams)(nil), nil) } // ListRoots lists the client roots. @@ -364,6 +364,25 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag return standardCall[CreateMessageResult](ctx, ss.conn, methodCreateMessage, params) } +// LoggingMessage sends a logging message to the client. +// The message is not sent if the client has not called SetLevel, or if its level +// is below that of the last SetLevel. +func (ss *ServerSession) LoggingMessage(ctx context.Context, params *LoggingMessageParams) error { + ss.mu.Lock() + logLevel := ss.logLevel + ss.mu.Unlock() + if logLevel == "" { + // The spec is unclear, but seems to imply that no log messages are sent until the client + // sets the level. 
+ // TODO(jba): read other SDKs, possibly file an issue. + return nil + } + if compareLevels(params.Level, logLevel) < 0 { + return nil + } + return ss.conn.Notify(ctx, notificationLoggingMessage, params) +} + // AddMiddleware wraps the server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one // is executed first. @@ -386,6 +405,7 @@ var serverMethodInfos = map[string]methodInfo[ServerSession]{ methodCallTool: newMethodInfo(serverMethod((*Server).callTool)), methodListResources: newMethodInfo(serverMethod((*Server).listResources)), methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), + methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), } @@ -457,6 +477,7 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam Resources: &resourceCapabilities{ ListChanged: true, }, + Logging: &loggingCapabilities{}, }, Instructions: ss.server.opts.Instructions, ServerInfo: &implementation{ @@ -470,6 +491,13 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (struct{}, error) { return struct{}{}, nil } +func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (struct{}, error) { + ss.mu.Lock() + defer ss.mu.Unlock() + ss.logLevel = params.Level + return struct{}{}, nil +} + // Close performs a graceful shutdown of the connection, preventing new // requests from being handled, and waiting for ongoing requests to return. // Close then terminates the connection. 
From 6e44d1ebf6ebaf441932b554e3dbbfee0635adb1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 19 May 2025 10:16:11 -0400 Subject: [PATCH 120/196] internal/mcp: future-proof prompt handling Pass GetPromptParams as needed, in case that type ever gets additional properties. Change-Id: I83fa37bb2af5633a8c2bcdc47a85c3030dfe70fb Reviewed-on: https://go-review.googlesource.com/c/tools/+/674895 Auto-Submit: Jonathan Amsterdam Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI --- internal/mcp/examples/hello/main.go | 6 +++--- internal/mcp/mcp_test.go | 19 ++++++++++--------- internal/mcp/prompt.go | 16 +++++++++------- internal/mcp/prompt_test.go | 2 +- internal/mcp/server.go | 2 +- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index b39b460f8ea..c672c5f393a 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -16,17 +16,17 @@ import ( var httpAddr = flag.String("http", "", "if set, use SSE HTTP at this address, instead of stdin/stdout") -type HiParams struct { +type HiArgs struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiArgs) ([]*mcp.Content, error) { return []*mcp.Content{ mcp.NewTextContent("Hi " + params.Name), }, nil } -func PromptHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) (*mcp.GetPromptResult, error) { +func PromptHi(ctx context.Context, cc *mcp.ServerSession, args *HiArgs, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []*mcp.PromptMessage{ diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index a9ca09b67c5..798da88e313 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -73,15 +73,16 @@ func TestEndToEnd(t *testing.T) { ) 
s.AddPrompts( - NewPrompt("code_review", "do a code review", func(_ context.Context, _ *ServerSession, params struct{ Code string }) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, - }, - }, nil - }), - NewPrompt("fail", "", func(_ context.Context, _ *ServerSession, params struct{}) (*GetPromptResult, error) { + NewPrompt("code_review", "do a code review", + func(_ context.Context, _ *ServerSession, params struct{ Code string }, _ *GetPromptParams) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, + }, + }, nil + }), + NewPrompt("fail", "", func(_ context.Context, _ *ServerSession, args struct{}, _ *GetPromptParams) (*GetPromptResult, error) { return nil, failure }), ) diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index 2c4c757f9bc..f57ccfb1069 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -16,7 +16,7 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerSession, map[string]string) (*GetPromptResult, error) +type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) // A Prompt is a prompt definition bound to a prompt handler. type ServerPrompt struct { @@ -24,15 +24,17 @@ type ServerPrompt struct { Handler PromptHandler } -// NewPrompt is a helper to use reflection to create a prompt for the given -// handler. +// NewPrompt is a helper that uses reflection to create a prompt for the given handler. // // The arguments for the prompt are extracted from the request type for the // handler. The handler request type must be a struct consisting only of fields // of type string or *string. 
The argument names for the resulting prompt // definition correspond to the JSON names of the request fields, and any // fields that are not marked "omitempty" are considered required. -func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt { +// +// The handler is passed [GetPromptParams] so it can have access to prompt parameters other than name and arguments. +// At present, there are no such parameters. +func NewPrompt[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq, *GetPromptParams) (*GetPromptResult, error), opts ...PromptOption) *ServerPrompt { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) @@ -60,10 +62,10 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, Required: required[name], }) } - prompt.Handler = func(ctx context.Context, cc *ServerSession, args map[string]string) (*GetPromptResult, error) { + prompt.Handler = func(ctx context.Context, ss *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { // For simplicity, just marshal and unmarshal the arguments. // This could be avoided in the future. - rawArgs, err := json.Marshal(args) + rawArgs, err := json.Marshal(params.Arguments) if err != nil { return nil, err } @@ -71,7 +73,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if err := unmarshalSchema(rawArgs, schema, &v); err != nil { return nil, err } - return handler(ctx, cc, v) + return handler(ctx, ss, v, params) } for _, opt := range opts { opt.set(prompt) diff --git a/internal/mcp/prompt_test.go b/internal/mcp/prompt_test.go index 4de5aa93d9f..eb6de5cae9c 100644 --- a/internal/mcp/prompt_test.go +++ b/internal/mcp/prompt_test.go @@ -13,7 +13,7 @@ import ( ) // testPromptHandler is used for type inference in TestNewPrompt. 
-func testPromptHandler[T any](context.Context, *mcp.ServerSession, T) (*mcp.GetPromptResult, error) { +func testPromptHandler[T any](context.Context, *mcp.ServerSession, T, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { panic("not implemented") } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 3e491dab982..87a14acecb4 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -180,7 +180,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr // TODO: surface the error code over the wire, instead of flattening it into the string. return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } - return prompt.Handler(ctx, cc, params.Arguments) + return prompt.Handler(ctx, cc, params) } func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { From f5ea575184fcfac30af466c0bf4676965e52f3fd Mon Sep 17 00:00:00 2001 From: cuishuang Date: Fri, 23 May 2025 17:25:24 +0800 Subject: [PATCH 121/196] go/ssa/interp: use slices.Equal to simplify code Change-Id: Ia479b7ef55acad189ae8cc0846780cbdc41b2298 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675517 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Reviewed-by: Robert Findley --- go/ssa/interp/external.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/go/ssa/interp/external.go b/go/ssa/interp/external.go index 2fb683c07fe..9de53ffd9d3 100644 --- a/go/ssa/interp/external.go +++ b/go/ssa/interp/external.go @@ -13,6 +13,7 @@ import ( "math" "os" "runtime" + "slices" "sort" "strconv" "strings" @@ -119,15 +120,7 @@ func ext۰bytes۰Equal(fr *frame, args []value) value { // func Equal(a, b []byte) bool a := args[0].([]value) b := args[1].([]value) - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true + return slices.Equal(a, b) } func ext۰bytes۰IndexByte(fr *frame, args 
[]value) value { From 60df06fb2b1141ef46bccd79d4745c954747b864 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 21 May 2025 15:31:56 +0000 Subject: [PATCH 122/196] internal/mcp: add pagination for tools This CL adds paginating functionality for tools. It uses the gob encoder for creating opaque cursors. More CLs to follow for paginating other features. Change-Id: I1443c0213ceb6238d844a8eee9a0be52934f5cab Reviewed-on: https://go-review.googlesource.com/c/tools/+/675055 Reviewed-by: Robert Findley Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/mcp/server.go | 93 +++++++++++++++- internal/mcp/server_example_test.go | 161 ++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+), 1 deletion(-) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 87a14acecb4..104fa645089 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -5,7 +5,10 @@ package mcp import ( + "bytes" "context" + "encoding/base64" + "encoding/gob" "fmt" "iter" "net/url" @@ -16,6 +19,8 @@ import ( jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" ) +const DefaultPageSize = 1000 + // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP @@ -40,6 +45,9 @@ type ServerOptions struct { Instructions string // If non-nil, called when "notifications/intialized" is received. InitializedHandler func(context.Context, *ServerSession, *InitializedParams) + // PageSize is the maximum number of items to return in a single page for + // list methods (e.g. ListTools). + PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. 
RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams) } @@ -55,6 +63,12 @@ func NewServer(name, version string, opts *ServerOptions) *Server { if opts == nil { opts = new(ServerOptions) } + if opts.PageSize < 0 { + panic(fmt.Errorf("invalid page size %d", opts.PageSize)) + } + if opts.PageSize == 0 { + opts.PageSize = DefaultPageSize + } return &Server{ name: name, version: version, @@ -186,8 +200,17 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() + var cursor string + if params != nil { + cursor = params.Cursor + } + tools, nextCursor, err := paginateList(s.tools, cursor, s.opts.PageSize) + if err != nil { + return nil, err + } res := new(ListToolsResult) - for t := range s.tools.all() { + res.NextCursor = nextCursor + for _, t := range tools { res.Tools = append(res.Tools, t.Tool) } return res, nil @@ -509,3 +532,71 @@ func (ss *ServerSession) Close() error { func (ss *ServerSession) Wait() error { return ss.conn.Wait() } + +// pageToken is the internal structure for the opaque pagination cursor. +// It will be Gob-encoded and then Base64-encoded for use as a string token. +type pageToken struct { + LastUID string // The unique ID of the last resource seen. +} + +// paginateList returns a slice of features from the given featureSet, based on +// the provided cursor and page size. It also returns a new cursor for the next +// page, or an empty string if there are no more pages. 
+func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) { + encodeCursor := func(uid string) (string, error) { + var buf bytes.Buffer + token := pageToken{LastUID: uid} + encoder := gob.NewEncoder(&buf) + if err := encoder.Encode(token); err != nil { + return "", fmt.Errorf("failed to encode page token: %w", err) + } + return base64.URLEncoding.EncodeToString(buf.Bytes()), nil + } + + decodeCursor := func(cursor string) (*pageToken, error) { + decodedBytes, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return nil, fmt.Errorf("failed to decode cursor: %w", err) + } + + var token pageToken + buf := bytes.NewBuffer(decodedBytes) + decoder := gob.NewDecoder(buf) + if err := decoder.Decode(&token); err != nil { + return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor) + } + return &token, nil + } + + var seq iter.Seq[T] + if cursor == "" { + seq = fs.all() + } else { + pageToken, err := decodeCursor(cursor) + // According to the spec, invalid cursors should return Invalid params. + if err != nil { + return nil, "", jsonrpc2.ErrInvalidParams + } + seq = fs.above(pageToken.LastUID) + } + var count int + for f := range seq { + count++ + // If we've seen pageSize + 1 elements, we've gathered enough info to determine + // if there's a next page. Stop processing the sequence. + if count == pageSize+1 { + break + } + features = append(features, f) + } + // No remaining pages. + if count < pageSize+1 { + return features, "", nil + } + // Trim the extra element from the result. 
+ nextCursor, err = encodeCursor(fs.uniqueID(features[len(features)-1])) + if err != nil { + return nil, "", err + } + return features, nextCursor, nil +} diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 4a9a9c7044c..fd1d22f3580 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -8,8 +8,13 @@ import ( "context" "fmt" "log" + "slices" + "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/tools/internal/mcp" + "golang.org/x/tools/internal/mcp/jsonschema" ) type SayHiParams struct { @@ -51,3 +56,159 @@ func ExampleServer() { // Output: Hi user } + +func TestListTool(t *testing.T) { + toolA := mcp.NewTool("apple", "apple tool", SayHi) + toolB := mcp.NewTool("banana", "banana tool", SayHi) + toolC := mcp.NewTool("cherry", "cherry tool", SayHi) + testCases := []struct { + tools []*mcp.ServerTool + want []*mcp.Tool + pageSize int + }{ + { + // Simple test. + []*mcp.ServerTool{toolA, toolB, toolC}, + []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, + mcp.DefaultPageSize, + }, + { + // Tools should be ordered by tool name. + []*mcp.ServerTool{toolC, toolA, toolB}, + []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, + mcp.DefaultPageSize, + }, + { + // Page size of 1 should yield the first tool only. + []*mcp.ServerTool{toolC, toolA, toolB}, + []*mcp.Tool{toolA.Tool}, + 1, + }, + { + // Page size of 2 should yield the first 2 tools only. + []*mcp.ServerTool{toolC, toolA, toolB}, + []*mcp.Tool{toolA.Tool, toolB.Tool}, + 2, + }, + { + // Page size of 3 should yield all tools. + []*mcp.ServerTool{toolC, toolA, toolB}, + []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, + 3, + }, + { + []*mcp.ServerTool{}, + nil, + 1, + }, + } + ctx := context.Background() + for _, tc := range testCases { + server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: tc.pageSize}) + server.AddTools(tc.tools...) 
+ clientTransport, serverTransport := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, serverTransport) + if err != nil { + log.Fatal(err) + } + client := mcp.NewClient("client", "v0.0.1", nil) + clientSession, err := client.Connect(ctx, clientTransport) + if err != nil { + log.Fatal(err) + } + res, err := clientSession.ListTools(ctx, nil) + serverSession.Close() + clientSession.Close() + if err != nil { + log.Fatal(err) + } + if len(res.Tools) != len(tc.want) { + t.Fatalf("expected %d tools, got %d", len(tc.want), len(res.Tools)) + } + if diff := cmp.Diff(res.Tools, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("expected tools %+v, got %+v", tc.want, res.Tools) + } + if tc.pageSize < len(tc.tools) && res.NextCursor == "" { + t.Fatalf("expected next cursor, got none") + } + } +} + +func TestListToolPaginateInvalidCursor(t *testing.T) { + toolA := mcp.NewTool("apple", "apple tool", SayHi) + ctx := context.Background() + server := mcp.NewServer("server", "v0.0.1", nil) + server.AddTools(toolA) + clientTransport, serverTransport := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, serverTransport) + if err != nil { + log.Fatal(err) + } + client := mcp.NewClient("client", "v0.0.1", nil) + clientSession, err := client.Connect(ctx, clientTransport) + if err != nil { + log.Fatal(err) + } + _, err = clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: "invalid"}) + if err == nil { + t.Fatalf("expected error, got none") + } + serverSession.Close() + clientSession.Close() +} + +func TestListToolPaginate(t *testing.T) { + serverTools := []*mcp.ServerTool{ + mcp.NewTool("apple", "apple tool", SayHi), + mcp.NewTool("banana", "banana tool", SayHi), + mcp.NewTool("cherry", "cherry tool", SayHi), + mcp.NewTool("durian", "durian tool", SayHi), + mcp.NewTool("elderberry", "elderberry tool", SayHi), + } + var wantTools []*mcp.Tool + for _, tool := range serverTools { + wantTools = append(wantTools, 
tool.Tool) + } + ctx := context.Background() + // Try all possible page sizes, ensuring we get the correct list of tools. + for pageSize := 1; pageSize < len(serverTools)+1; pageSize++ { + server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: pageSize}) + server.AddTools(serverTools...) + clientTransport, serverTransport := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, serverTransport) + if err != nil { + log.Fatal(err) + } + client := mcp.NewClient("client", "v0.0.1", nil) + clientSession, err := client.Connect(ctx, clientTransport) + if err != nil { + log.Fatal(err) + } + var gotTools []*mcp.Tool + var nextCursor string + wantChunks := slices.Collect(slices.Chunk(wantTools, pageSize)) + index := 0 + // Iterate through all pages, comparing sub-slices to the paginated list. + for { + res, err := clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: nextCursor}) + if err != nil { + log.Fatal(err) + } + gotTools = append(gotTools, res.Tools...) 
+ nextCursor = res.NextCursor + if diff := cmp.Diff(res.Tools, wantChunks[index], cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("expected %v, got %v, (-want +got):\n%s", wantChunks[index], res.Tools, diff) + } + if res.NextCursor == "" { + break + } + index++ + } + serverSession.Close() + clientSession.Close() + + if len(gotTools) != len(wantTools) { + t.Fatalf("expected %d tools, got %d", len(wantTools), len(gotTools)) + } + } +} From 08af7d4b4570722fcf42c4a4d1dbe1c5d6325215 Mon Sep 17 00:00:00 2001 From: cuishuang Date: Fri, 23 May 2025 15:07:50 +0800 Subject: [PATCH 123/196] all: fix some function names and typos in comment Change-Id: I8066b22cb87e09fb160de3b8ff2d8fdde6ec3dcf Reviewed-on: https://go-review.googlesource.com/c/tools/+/675815 Reviewed-by: David Chase Auto-Submit: Alan Donovan Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- cmd/signature-fuzzer/internal/fuzz-generator/gen_test.go | 2 +- cmd/stringer/endtoend_test.go | 2 +- go/analysis/checker/example_test.go | 2 +- go/analysis/doc/suggested_fixes.md | 2 +- go/analysis/internal/checker/checker.go | 2 +- go/analysis/passes/asmdecl/asmdecl.go | 2 +- go/analysis/passes/loopclosure/loopclosure.go | 2 +- go/analysis/passes/structtag/structtag.go | 2 +- go/analysis/unitchecker/unitchecker_test.go | 2 +- go/callgraph/vta/graph.go | 4 ++-- go/callgraph/vta/internal/trie/builder.go | 2 +- go/loader/doc.go | 2 +- go/packages/overlay_test.go | 4 ++-- go/packages/packages_test.go | 4 ++-- go/ssa/builder.go | 2 +- go/ssa/builder_generic_test.go | 2 +- go/ssa/subst.go | 2 +- go/ssa/util.go | 2 +- gopls/doc/assets/go.mod | 2 +- gopls/doc/contributing.md | 2 +- gopls/doc/emacs.md | 2 +- gopls/doc/features/passive.md | 2 +- gopls/doc/features/templates.md | 2 +- gopls/doc/release/v0.17.0.md | 4 ++-- gopls/doc/release/v0.19.0.md | 2 +- .../analysis/modernize/testdata/src/minmax/minmax.go | 2 +- .../analysis/modernize/testdata/src/minmax/minmax.go.golden | 2 +- 
.../testdata/src/generatedcode/nongeneratedcode.go | 2 +- .../testdata/src/generatedcode/nongeneratedcode.go.golden | 2 +- gopls/internal/analysis/yield/yield.go | 2 +- gopls/internal/cache/check.go | 6 +++--- gopls/internal/cache/mod_tidy.go | 2 +- gopls/internal/cache/parse_cache.go | 2 +- gopls/internal/cache/parsego/parse.go | 4 ++-- gopls/internal/cache/session.go | 4 ++-- gopls/internal/cache/snapshot.go | 6 +++--- gopls/internal/cmd/codelens.go | 2 +- gopls/internal/cmd/usage/codelens.hlp | 2 +- gopls/internal/golang/addtest.go | 2 +- gopls/internal/golang/completion/completion.go | 4 ++-- gopls/internal/golang/completion/keywords.go | 4 ++-- gopls/internal/golang/completion/literal.go | 2 +- gopls/internal/golang/completion/package.go | 2 +- gopls/internal/golang/highlight.go | 2 +- gopls/internal/golang/implementation.go | 2 +- gopls/internal/golang/rename_check.go | 2 +- gopls/internal/golang/type_hierarchy.go | 2 +- gopls/internal/protocol/command/interface.go | 2 +- gopls/internal/server/command.go | 2 +- gopls/internal/server/prompt.go | 2 +- gopls/internal/settings/analysis.go | 2 +- gopls/internal/test/integration/bench/bench_test.go | 2 +- gopls/internal/test/integration/bench/imports_test.go | 2 +- gopls/internal/test/integration/misc/hover_test.go | 2 +- gopls/internal/test/integration/web/pkdoc_test.go | 2 +- gopls/internal/test/marker/testdata/quickfix/stub.txt | 2 +- internal/analysisinternal/analysis.go | 4 ++-- internal/diff/lcs/common.go | 2 +- internal/imports/fix.go | 2 +- internal/imports/imports.go | 2 +- internal/mcp/README.md | 2 +- internal/mcp/content.go | 2 +- internal/mcp/internal/readme/server/server.go | 2 +- internal/mcp/internal/util/util.go | 4 ++-- internal/mcp/jsonschema/resolve_test.go | 2 +- internal/mcp/server.go | 4 ++-- internal/mcp/shared.go | 2 +- internal/mcp/sse.go | 4 ++-- internal/modindex/directories.go | 2 +- internal/modindex/gomodindex/cmd.go | 4 ++-- internal/modindex/lookup.go | 2 +- 
internal/modindex/symbols.go | 4 ++-- internal/refactor/inline/callee.go | 2 +- internal/refactor/inline/inline.go | 2 +- internal/refactor/inline/testdata/method.txtar | 2 +- internal/typesinternal/zerovalue_test.go | 2 +- 76 files changed, 94 insertions(+), 94 deletions(-) diff --git a/cmd/signature-fuzzer/internal/fuzz-generator/gen_test.go b/cmd/signature-fuzzer/internal/fuzz-generator/gen_test.go index f10a7e9a7df..1c0cc993fba 100644 --- a/cmd/signature-fuzzer/internal/fuzz-generator/gen_test.go +++ b/cmd/signature-fuzzer/internal/fuzz-generator/gen_test.go @@ -112,7 +112,7 @@ func TestIsBuildable(t *testing.T) { verb(1, "output is: %s\n", string(coutput)) } -// TestExhaustive does a series of code genreation runs, starting with +// TestExhaustive does a series of code generation runs, starting with // (relatively) simple code and then getting progressively more // complex (more params, deeper structs, turning on additional // features such as address-taken vars and reflect testing). The diff --git a/cmd/stringer/endtoend_test.go b/cmd/stringer/endtoend_test.go index 721c1f68df5..8e062cc80f5 100644 --- a/cmd/stringer/endtoend_test.go +++ b/cmd/stringer/endtoend_test.go @@ -286,7 +286,7 @@ func TestTestFiles(t *testing.T) { } } -// The -output flag cannot be used in combiation with matching types across multiple packages. +// The -output flag cannot be used in combination with matching types across multiple packages. 
func TestCollidingOutput(t *testing.T) { testenv.NeedsTool(t, "go") stringer := stringerPath(t) diff --git a/go/analysis/checker/example_test.go b/go/analysis/checker/example_test.go index 91beeb1ed3f..524c5b7f2c5 100644 --- a/go/analysis/checker/example_test.go +++ b/go/analysis/checker/example_test.go @@ -69,7 +69,7 @@ func Example() { // min=bufio max=unsafe } -// minmaxpkg is a trival example analyzer that uses package facts to +// minmaxpkg is a trivial example analyzer that uses package facts to // compute information from the entire dependency graph. var minmaxpkg = &analysis.Analyzer{ Name: "minmaxpkg", diff --git a/go/analysis/doc/suggested_fixes.md b/go/analysis/doc/suggested_fixes.md index 74888f8a96e..6fa033fc136 100644 --- a/go/analysis/doc/suggested_fixes.md +++ b/go/analysis/doc/suggested_fixes.md @@ -79,7 +79,7 @@ These requirements guarantee that suggested fixes can be cleanly applied. Because a driver may only analyze, or be able to modify, the current package, we restrict edits to the current package. In general this restriction should not be a big problem for users because other packages might not belong to the -same module and so will not be safe to modify in a singe change. +same module and so will not be safe to modify in a single change. On the other hand, analyzers will not be required to produce gofmt-compliant code. Analysis drivers will be expected to apply gofmt to the results of diff --git a/go/analysis/internal/checker/checker.go b/go/analysis/internal/checker/checker.go index bc57dc6e673..19ebdac8460 100644 --- a/go/analysis/internal/checker/checker.go +++ b/go/analysis/internal/checker/checker.go @@ -99,7 +99,7 @@ func Run(args []string, analyzers []*analysis.Analyzer) (exitcode int) { // without having to remember what code to return. // // TODO(adonovan): interpreting exit codes is like reading tea-leaves. 
- // Insted of wasting effort trying to encode a multidimensional result + // Instead of wasting effort trying to encode a multidimensional result // into 7 bits we should just emit structured JSON output, and // an exit code of 0 or 1 for success or failure. exitAtLeast := func(code int) { diff --git a/go/analysis/passes/asmdecl/asmdecl.go b/go/analysis/passes/asmdecl/asmdecl.go index 436b03cb290..1aa7afb9c2a 100644 --- a/go/analysis/passes/asmdecl/asmdecl.go +++ b/go/analysis/passes/asmdecl/asmdecl.go @@ -57,7 +57,7 @@ type asmArch struct { // include the first integer register and first floating-point register. Accessing // any of them counts as writing to result. retRegs []string - // writeResult is a list of instructions that will change result register implicity. + // writeResult is a list of instructions that will change result register implicitly. writeResult []string // calculated during initialization sizes types.Sizes diff --git a/go/analysis/passes/loopclosure/loopclosure.go b/go/analysis/passes/loopclosure/loopclosure.go index 64df1b106a1..2580a0ac21f 100644 --- a/go/analysis/passes/loopclosure/loopclosure.go +++ b/go/analysis/passes/loopclosure/loopclosure.go @@ -88,7 +88,7 @@ func run(pass *analysis.Pass) (any, error) { // // TODO: consider allowing the "last" go/defer/Go statement to be followed by // N "trivial" statements, possibly under a recursive definition of "trivial" - // so that that checker could, for example, conclude that a go statement is + // so that checker could, for example, conclude that a go statement is // followed by an if statement made of only trivial statements and trivial expressions, // and hence the go statement could still be checked. 
forEachLastStmt(body.List, func(last ast.Stmt) { diff --git a/go/analysis/passes/structtag/structtag.go b/go/analysis/passes/structtag/structtag.go index 13a9997316e..cc90f7335ec 100644 --- a/go/analysis/passes/structtag/structtag.go +++ b/go/analysis/passes/structtag/structtag.go @@ -107,7 +107,7 @@ func checkCanonicalFieldTag(pass *analysis.Pass, field *types.Var, tag string, s // Embedded struct. Nothing to do for now, but that // may change, depending on what happens with issue 7363. - // TODO(adonovan): investigate, now that that issue is fixed. + // TODO(adonovan): investigate, now that issue is fixed. if field.Anonymous() { return } diff --git a/go/analysis/unitchecker/unitchecker_test.go b/go/analysis/unitchecker/unitchecker_test.go index 6c3bba6793e..6ab23bf61fd 100644 --- a/go/analysis/unitchecker/unitchecker_test.go +++ b/go/analysis/unitchecker/unitchecker_test.go @@ -170,7 +170,7 @@ func _() { // TODO(golang/go#65729): this is unsound: any extra // logging by the child process (e.g. due to GODEBUG // options) will add noise to stderr, causing the - // CombinedOutput to be unparseable as JSON. But we + // CombinedOutput to be unparsable as JSON. But we // can't simply use Output here as some of the tests // look for substrings of stderr. Rework the test to // be specific about which output stream to match. diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go index 26225e7db37..f66f4c73b14 100644 --- a/go/callgraph/vta/graph.go +++ b/go/callgraph/vta/graph.go @@ -658,7 +658,7 @@ func (b *builder) call(c ssa.CallInstruction) { } func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) { - // When f has no paremeters (including receiver), there is no type + // When f has no parameters (including receiver), there is no type // flow here. Also, f's body and parameters might be missing, such // as when vta is used within the golang.org/x/tools/go/analysis // framework (see github.com/golang/go/issues/50670). 
@@ -803,7 +803,7 @@ func (b *builder) nodeFromVal(val ssa.Value) node { return function{f: v} case *ssa.Parameter, *ssa.FreeVar, ssa.Instruction: // ssa.Param, ssa.FreeVar, and a specific set of "register" instructions, - // satisifying the ssa.Value interface, can serve as local variables. + // satisfying the ssa.Value interface, can serve as local variables. return local{val: v} default: panic(fmt.Errorf("unsupported value %v in node creation", val)) diff --git a/go/callgraph/vta/internal/trie/builder.go b/go/callgraph/vta/internal/trie/builder.go index bdd39397ec6..adef1282d81 100644 --- a/go/callgraph/vta/internal/trie/builder.go +++ b/go/callgraph/vta/internal/trie/builder.go @@ -245,7 +245,7 @@ func (b *Builder) create(leaves []*leaf) node { } else if n == 1 { return leaves[0] } - // Note: we can do a more sophisicated algorithm by: + // Note: we can do a more sophisticated algorithm by: // - sorting the leaves ahead of time, // - taking the prefix and branching bit of the min and max key, // - binary searching for the branching bit, diff --git a/go/loader/doc.go b/go/loader/doc.go index e35b1fd7d93..769a1fcf7c6 100644 --- a/go/loader/doc.go +++ b/go/loader/doc.go @@ -164,7 +164,7 @@ package loader // entry is created in this cache by startLoad the first time the // package is imported. The first goroutine to request an entry becomes // responsible for completing the task and broadcasting completion to -// subsequent requestors, which block until then. +// subsequent requesters, which block until then. 
// // Type checking occurs in (parallel) postorder: we cannot type-check a // set of files until we have loaded and type-checked all of their diff --git a/go/packages/overlay_test.go b/go/packages/overlay_test.go index 4a7cc68f4c7..f2e1d9584e1 100644 --- a/go/packages/overlay_test.go +++ b/go/packages/overlay_test.go @@ -81,7 +81,7 @@ func testOverlayChangesBothPackageNames(t *testing.T, exporter packagestest.Expo t.Fatalf("failed to load: %v", err) } if len(initial) != 3 { - t.Errorf("got %d packges, expected 3", len(initial)) + t.Errorf("got %d packages, expected 3", len(initial)) } want := []struct { id, name string @@ -127,7 +127,7 @@ func testOverlayChangesTestPackageName(t *testing.T, exporter packagestest.Expor t.Fatalf("failed to load: %v", err) } if len(initial) != 3 { - t.Errorf("got %d packges, expected 3", len(initial)) + t.Errorf("got %d packages, expected 3", len(initial)) } want := []struct { id, name string diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index ae3cbb6bb2b..fa577345e3c 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -2676,7 +2676,7 @@ func testIssue48226(t *testing.T, exporter packagestest.Exporter) { t.Fatal(err) } if len(initial) != 1 { - t.Fatalf("exepected 1 package, got %d", len(initial)) + t.Fatalf("expected 1 package, got %d", len(initial)) } pkg := initial[0] @@ -2721,7 +2721,7 @@ func testModule(t *testing.T, exporter packagestest.Exporter) { t.Fatal("package.Module: want non-nil, got nil") } if a.Module.Path != "golang.org/fake" { - t.Fatalf("package.Modile.Path: want \"golang.org/fake\", got %q", a.Module.Path) + t.Fatalf("package.Module.Path: want \"golang.org/fake\", got %q", a.Module.Path) } if a.Module.GoMod != filepath.Join(rootDir, "go.mod") { t.Fatalf("package.Module.GoMod: want %q, got %q", filepath.Join(rootDir, "go.mod"), a.Module.GoMod) diff --git a/go/ssa/builder.go b/go/ssa/builder.go index b76b75ea025..fe713a77b61 100644 --- a/go/ssa/builder.go +++ 
b/go/ssa/builder.go @@ -25,7 +25,7 @@ package ssa // populating fields such as Function.Body, .Params, and others. // // Building may create additional methods, including: -// - wrapper methods (e.g. for embeddding, or implicit &recv) +// - wrapper methods (e.g. for embedding, or implicit &recv) // - bound method closures (e.g. for use(recv.f)) // - thunks (e.g. for use(I.f) or use(T.f)) // - generic instances (e.g. to produce f[int] from f[any]). diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go index af16036dfa9..f2af808e911 100644 --- a/go/ssa/builder_generic_test.go +++ b/go/ssa/builder_generic_test.go @@ -766,7 +766,7 @@ func TestInstructionString(t *testing.T) { // Expectation is a {function, type string} -> {want, matches} // where matches is all Instructions.String() that match the key. - // Each expecation is that some permutation of matches is wants. + // Each expectation is that some permutation of matches is wants. type expKey struct { function string kind string diff --git a/go/ssa/subst.go b/go/ssa/subst.go index b4ea16854ea..362dce1267b 100644 --- a/go/ssa/subst.go +++ b/go/ssa/subst.go @@ -543,7 +543,7 @@ func (subst *subster) signature(t *types.Signature) types.Type { // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. // // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 
- // To support tparams.Len() > 0, we just need to do the following [psuedocode]: + // To support tparams.Len() > 0, we just need to do the following [pseudocode]: // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") diff --git a/go/ssa/util.go b/go/ssa/util.go index e53b31ff3bb..932eb6cb0e7 100644 --- a/go/ssa/util.go +++ b/go/ssa/util.go @@ -25,7 +25,7 @@ type unit struct{} //// Sanity checking utilities -// assert panics with the mesage msg if p is false. +// assert panics with the message msg if p is false. // Avoid combining with expensive string formatting. func assert(p bool, msg string) { if !p { diff --git a/gopls/doc/assets/go.mod b/gopls/doc/assets/go.mod index 9b417f19ed8..ca76ffefd55 100644 --- a/gopls/doc/assets/go.mod +++ b/gopls/doc/assets/go.mod @@ -1,6 +1,6 @@ // This module contains no Go code, but serves to carve out a hole in // its parent module to avoid bloating it with large image files that -// would otherwise be dowloaded by "go install golang.org/x/tools/gopls@latest". +// would otherwise be downloaded by "go install golang.org/x/tools/gopls@latest". module golang.org/x/tools/gopls/doc/assets diff --git a/gopls/doc/contributing.md b/gopls/doc/contributing.md index 94752c5394d..fcbe7d65e59 100644 --- a/gopls/doc/contributing.md +++ b/gopls/doc/contributing.md @@ -188,7 +188,7 @@ Jenkins-like Google infrastructure for running Dockerized tests. This allows us to run gopls tests in various environments that would be difficult to add to the TryBots. Notably, Kokoro runs tests on [older Go versions](../README.md#supported-go-versions) that are no longer supported -by the TryBots. Per that that policy, support for these older Go versions is +by the TryBots. Per that policy, support for these older Go versions is best-effort, and test failures may be skipped rather than fixed. 
Kokoro runs are triggered by the `Run-TryBot=1` label, just like TryBots, but diff --git a/gopls/doc/emacs.md b/gopls/doc/emacs.md index 3b6ee80d05a..48553d71938 100644 --- a/gopls/doc/emacs.md +++ b/gopls/doc/emacs.md @@ -110,7 +110,7 @@ project root. ;; Optional: install eglot-format-buffer as a save hook. ;; The depth of -10 places this before eglot's willSave notification, -;; so that that notification reports the actual contents that will be saved. +;; so that notification reports the actual contents that will be saved. (defun eglot-format-buffer-before-save () (add-hook 'before-save-hook #'eglot-format-buffer -10 t)) (add-hook 'go-mode-hook #'eglot-format-buffer-before-save) diff --git a/gopls/doc/features/passive.md b/gopls/doc/features/passive.md index 77f7b2f0c06..e1c8c5bdc58 100644 --- a/gopls/doc/features/passive.md +++ b/gopls/doc/features/passive.md @@ -46,7 +46,7 @@ structures, or when reading assembly files or stack traces that refer to each field by its cryptic byte offset. In addition, Hover reports: -- the struct's size class, which is the number of of bytes actually +- the struct's size class, which is the number of bytes actually allocated by the Go runtime for a single object of this type; and - the percentage of wasted space due to suboptimal ordering of struct fields, if this figure is 20% or higher: diff --git a/gopls/doc/features/templates.md b/gopls/doc/features/templates.md index a71a2ea181c..f734fcd84fa 100644 --- a/gopls/doc/features/templates.md +++ b/gopls/doc/features/templates.md @@ -14,7 +14,7 @@ value, since Go templates don't have a canonical file extension.) Additional configuration may be necessary to ensure that your client chooses the correct language kind when opening template files. -Gopls recogizes both `"tmpl"` and `"gotmpl"` for template files. +Gopls recognizes both `"tmpl"` and `"gotmpl"` for template files. 
For example, in `VS Code` you will also need to add an entry to the [`files.associations`](https://code.visualstudio.com/docs/languages/identifiers) diff --git a/gopls/doc/release/v0.17.0.md b/gopls/doc/release/v0.17.0.md index e6af9c6bf26..786891bd1fd 100644 --- a/gopls/doc/release/v0.17.0.md +++ b/gopls/doc/release/v0.17.0.md @@ -17,7 +17,7 @@ reduce the considerable costs to us of testing against older Go versions, allowing us to spend more time fixing bugs and adding features that benefit the majority of gopls users who run recent versions of Go. -This narrowing is occuring in two dimensions: **build compatibility** refers to +This narrowing is occurring in two dimensions: **build compatibility** refers to the versions of the Go toolchain that can be used to build gopls, and **go command compatibility** refers to the versions of the `go` command that can be used by gopls to list information about packages and modules in your workspace. @@ -110,7 +110,7 @@ The user can invoke this code action by selecting a function name, the keywords or by selecting a whole declaration or multiple declarations. In order to avoid ambiguity and surprise about what to extract, some kinds -of paritial selection of a declaration cannot invoke this code action. +of partial selection of a declaration cannot invoke this code action. ### Extract constant diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index 9ae1feb6c36..94f225a800f 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -35,7 +35,7 @@ Slightly more than half of the analyzers in the enabled by default. This subset has been chosen for precision and efficiency. -Prevously, Staticcheck analyzers (all of them) would be run only if +Previously, Staticcheck analyzers (all of them) would be run only if the experimental `staticcheck` boolean option was set to `true`. This value continues to enable the complete set, and a value of `false` continues to disable the complete set. 
Leaving the option unspecified diff --git a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go index cdc767450d2..5f404ed717d 100644 --- a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go +++ b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go @@ -139,7 +139,7 @@ func fix72727(a, b int) { type myfloat float64 -// The built-in min/max differ in their treatement of NaN, +// The built-in min/max differ in their treatment of NaN, // so reject floating-point numbers (#72829). func nopeFloat(a, b myfloat) (res myfloat) { if a < b { diff --git a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden index b7be86bf416..a13c72db5c0 100644 --- a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden @@ -126,7 +126,7 @@ func fix72727(a, b int) { type myfloat float64 -// The built-in min/max differ in their treatement of NaN, +// The built-in min/max differ in their treatment of NaN, // so reject floating-point numbers (#72829). func nopeFloat(a, b myfloat) (res myfloat) { if a < b { diff --git a/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go b/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go index fe0ef94afbb..cad36d39aa8 100644 --- a/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go +++ b/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go @@ -12,7 +12,7 @@ type implementsGeneratedInterface struct{} // in the generated code. func (implementsGeneratedInterface) n(f bool) { // The body must not be empty, otherwise unusedparams will - // not report the unused parameter regardles of the + // not report the unused parameter regardless of the // interface. 
println() } diff --git a/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go.golden b/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go.golden index 170dc85785c..44d24fb55e3 100644 --- a/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go.golden +++ b/gopls/internal/analysis/unusedparams/testdata/src/generatedcode/nongeneratedcode.go.golden @@ -12,7 +12,7 @@ type implementsGeneratedInterface struct{} // in the generated code. func (implementsGeneratedInterface) n(f bool) { // The body must not be empty, otherwise unusedparams will - // not report the unused parameter regardles of the + // not report the unused parameter regardless of the // interface. println() } diff --git a/gopls/internal/analysis/yield/yield.go b/gopls/internal/analysis/yield/yield.go index 354cf372186..4cac152a88c 100644 --- a/gopls/internal/analysis/yield/yield.go +++ b/gopls/internal/analysis/yield/yield.go @@ -14,7 +14,7 @@ package yield // // seq(yield) // -// to avoid unnecesary range desugaring and chains of dynamic calls. +// to avoid unnecessary range desugaring and chains of dynamic calls. import ( _ "embed" diff --git a/gopls/internal/cache/check.go b/gopls/internal/cache/check.go index bee0616c8a1..a23ec4c2937 100644 --- a/gopls/internal/cache/check.go +++ b/gopls/internal/cache/check.go @@ -1688,7 +1688,7 @@ func (b *typeCheckBatch) checkPackage(ctx context.Context, fset *token.FileSet, // Track URIs with parse errors so that we can suppress type errors for these // files. 
- unparseable := map[protocol.DocumentURI]bool{} + unparsable := map[protocol.DocumentURI]bool{} for _, e := range pkg.parseErrors { diags, err := parseErrorDiagnostics(pkg, e) if err != nil { @@ -1696,7 +1696,7 @@ func (b *typeCheckBatch) checkPackage(ctx context.Context, fset *token.FileSet, continue } for _, diag := range diags { - unparseable[diag.URI] = true + unparsable[diag.URI] = true pkg.diagnostics = append(pkg.diagnostics, diag) } } @@ -1706,7 +1706,7 @@ func (b *typeCheckBatch) checkPackage(ctx context.Context, fset *token.FileSet, // If the file didn't parse cleanly, it is highly likely that type // checking errors will be confusing or redundant. But otherwise, type // checking usually provides a good enough signal to include. - if !unparseable[diag.URI] { + if !unparsable[diag.URI] { pkg.diagnostics = append(pkg.diagnostics, diag) } } diff --git a/gopls/internal/cache/mod_tidy.go b/gopls/internal/cache/mod_tidy.go index 6d9a3e56b81..434f6f9c760 100644 --- a/gopls/internal/cache/mod_tidy.go +++ b/gopls/internal/cache/mod_tidy.go @@ -45,7 +45,7 @@ func (s *Snapshot) ModTidy(ctx context.Context, pm *ParsedModule) (*TidiedModule uri := pm.URI if pm.File == nil { - return nil, fmt.Errorf("cannot tidy unparseable go.mod file: %v", uri) + return nil, fmt.Errorf("cannot tidy unparsable go.mod file: %v", uri) } s.mu.Lock() diff --git a/gopls/internal/cache/parse_cache.go b/gopls/internal/cache/parse_cache.go index 015510b881d..259587408f5 100644 --- a/gopls/internal/cache/parse_cache.go +++ b/gopls/internal/cache/parse_cache.go @@ -391,7 +391,7 @@ func (c *parseCache) parseFiles(ctx context.Context, fset *token.FileSet, mode p // -- priority queue boilerplate -- -// queue is a min-atime prority queue of cache entries. +// queue is a min-atime priority queue of cache entries. 
type queue []*parseCacheEntry func (q queue) Len() int { return len(q) } diff --git a/gopls/internal/cache/parsego/parse.go b/gopls/internal/cache/parsego/parse.go index 3346edd2b7a..2708e2b262b 100644 --- a/gopls/internal/cache/parsego/parse.go +++ b/gopls/internal/cache/parsego/parse.go @@ -82,7 +82,7 @@ func Parse(ctx context.Context, fset *token.FileSet, uri protocol.DocumentURI, s } for i := range 10 { - // Fix certain syntax errors that render the file unparseable. + // Fix certain syntax errors that render the file unparsable. newSrc, srcFix := fixSrc(file, tok, src) if newSrc == nil { break @@ -213,7 +213,7 @@ func fixAST(n ast.Node, tok *token.File, src []byte) (fixes []FixType) { return fixes } -// TODO(rfindley): revert this intrumentation once we're certain the crash in +// TODO(rfindley): revert this instrumentation once we're certain the crash in // #59097 is fixed. type FixType int diff --git a/gopls/internal/cache/session.go b/gopls/internal/cache/session.go index f0d8f062138..82472e82a95 100644 --- a/gopls/internal/cache/session.go +++ b/gopls/internal/cache/session.go @@ -152,7 +152,7 @@ func (s *Session) HasView(uri protocol.DocumentURI) bool { } // createView creates a new view, with an initial snapshot that retains the -// supplied context, detached from events and cancelation. +// supplied context, detached from events and cancellation. // // The caller is responsible for calling the release function once. func (s *Session) createView(ctx context.Context, def *viewDefinition) (*View, *Snapshot, func()) { @@ -418,7 +418,7 @@ func (s *Session) SnapshotOf(ctx context.Context, uri protocol.DocumentURI) (*Sn continue // view was shut down } // We don't check the error from awaitLoaded, because a load failure (that - // doesn't result from context cancelation) should not prevent us from + // doesn't result from context cancellation) should not prevent us from // continuing to search for the best view. 
_ = snapshot.awaitLoaded(ctx) g := snapshot.MetadataGraph() diff --git a/gopls/internal/cache/snapshot.go b/gopls/internal/cache/snapshot.go index 8dda86071de..e78c1bba010 100644 --- a/gopls/internal/cache/snapshot.go +++ b/gopls/internal/cache/snapshot.go @@ -102,7 +102,7 @@ type Snapshot struct { // initialErr holds the last error resulting from initialization. If // initialization fails, we only retry when the workspace modules change, // to avoid too many go/packages calls. - // If initialized is false, initialErr stil holds the error resulting from + // If initialized is false, initialErr still holds the error resulting from // the previous initialization. // TODO(rfindley): can we unify the lifecycle of initialized and initialErr. initialErr *InitializationError @@ -1762,7 +1762,7 @@ func (s *Snapshot) clone(ctx, bgCtx context.Context, changed StateChange, done f // // We could also do better by looking at which imports were deleted and // trying to find cycles they are involved in. This fails when the file goes - // from an unparseable state to a parseable state, as we don't have a + // from an unparsable state to a parseable state, as we don't have a // starting point to compare with. if anyImportDeleted { for id, mp := range s.meta.Packages { @@ -2090,7 +2090,7 @@ func metadataChanges(ctx context.Context, lockedSnapshot *Snapshot, oldFH, newFH } else { // At this point, we shouldn't ever fail to produce a parsego.File, as // we're already past header parsing. 
- bug.Reportf("metadataChanges: unparseable file %v (old error: %v, new error: %v)", oldFH.URI(), oldErr, newErr) + bug.Reportf("metadataChanges: unparsable file %v (old error: %v, new error: %v)", oldFH.URI(), oldErr, newErr) } } diff --git a/gopls/internal/cmd/codelens.go b/gopls/internal/cmd/codelens.go index 074733e58f5..55424a395e0 100644 --- a/gopls/internal/cmd/codelens.go +++ b/gopls/internal/cmd/codelens.go @@ -32,7 +32,7 @@ The codelens command lists or executes code lenses for the specified file, or line within a file. A code lens is a command associated with a position in the code. -With an optional title argment, only code lenses matching that +With an optional title argument, only code lenses matching that title are considered. By default, the codelens command lists the available lenses for the diff --git a/gopls/internal/cmd/usage/codelens.hlp b/gopls/internal/cmd/usage/codelens.hlp index f72bb465e07..5b72961e44e 100644 --- a/gopls/internal/cmd/usage/codelens.hlp +++ b/gopls/internal/cmd/usage/codelens.hlp @@ -7,7 +7,7 @@ The codelens command lists or executes code lenses for the specified file, or line within a file. A code lens is a command associated with a position in the code. -With an optional title argment, only code lenses matching that +With an optional title argument, only code lenses matching that title are considered. By default, the codelens command lists the available lenses for the diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index 66ed9716c9a..da9a8ecc88c 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -182,7 +182,7 @@ type testInfo struct { // TestingPackageName is the package name should be used when referencing // package "testing" TestingPackageName string - // PackageName is the package name the target function/method is delcared from. + // PackageName is the package name the target function/method is declared from. 
PackageName string TestFuncName string // Func holds information about the function or method being tested. diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index d6b49ca9d04..13793995561 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -2428,8 +2428,8 @@ Nodes: if inst != nil { // TODO(jacobz): If partial signature instantiation becomes possible, // make needsExactType only true if necessary. - // Currently, ambigious cases always resolve to a conversion expression - // wrapping the completion, which is occassionally superfluous. + // Currently, ambiguous cases always resolve to a conversion expression + // wrapping the completion, which is occasionally superfluous. inf.needsExactType = true sig = inst } diff --git a/gopls/internal/golang/completion/keywords.go b/gopls/internal/golang/completion/keywords.go index fb1fa1694ce..8ed83e8ad07 100644 --- a/gopls/internal/golang/completion/keywords.go +++ b/gopls/internal/golang/completion/keywords.go @@ -122,7 +122,7 @@ func (c *completer) addKeywordCompletions() { } case *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.SwitchStmt: // if there is no default case yet, it's highly likely to add a default in switch. - // we don't offer 'default' anymore if user has used it already in current swtich. + // we don't offer 'default' anymore if user has used it already in current switch. if !hasDefaultClause(node) { c.addKeywordItems(seen, highScore, CASE, DEFAULT) } @@ -152,7 +152,7 @@ func (c *completer) addKeywordCompletions() { // as user must return something, we offer a space after return. // function literal inside a function will be affected by outer function, // but 'go fmt' will help to remove the ending space. - // the benefit is greater than introducing an unncessary space. + // the benefit is greater than introducing an unnecessary space. 
ret += " " } diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 5dc364724c6..572a7c4c2ca 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -343,7 +343,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m for i := range results.Len() { if resultHasTypeParams && !c.opts.placeholders { // Leave an empty tabstop if placeholders are disabled and there - // are type args that need specificying. + // are type args that need specifying. snip.WritePlaceholder(nil) break } diff --git a/gopls/internal/golang/completion/package.go b/gopls/internal/golang/completion/package.go index d1698ee6580..ae421d821c0 100644 --- a/gopls/internal/golang/completion/package.go +++ b/gopls/internal/golang/completion/package.go @@ -95,7 +95,7 @@ func packageCompletionSurrounding(pgf *parsego.File, offset int) (*Selection, er fset := token.NewFileSet() expr, _ := parser.ParseExprFrom(fset, m.URI.Path(), pgf.Src, parser.Mode(0)) if expr == nil { - return nil, fmt.Errorf("unparseable file (%s)", m.URI) + return nil, fmt.Errorf("unparsable file (%s)", m.URI) } tok := fset.File(expr.Pos()) cursor := tok.Pos(offset) diff --git a/gopls/internal/golang/highlight.go b/gopls/internal/golang/highlight.go index ee82b622a71..1d97d2819af 100644 --- a/gopls/internal/golang/highlight.go +++ b/gopls/internal/golang/highlight.go @@ -358,7 +358,7 @@ findEnclosingFunc: } else if returnStmt == nil && !inResults { return // nothing to highlight } else { - // If we're not highighting the entire return statement, we need to collect + // If we're not highlighting the entire return statement, we need to collect // specific result indexes to highlight. This may be more than one index if // the cursor is on a multi-name result field, but not in any specific name. 
if !highlightAll { diff --git a/gopls/internal/golang/implementation.go b/gopls/internal/golang/implementation.go index 678861440da..9aaf27d7d1a 100644 --- a/gopls/internal/golang/implementation.go +++ b/gopls/internal/golang/implementation.go @@ -280,7 +280,7 @@ func implementationsMsets(ctx context.Context, snapshot *cache.Snapshot, pkg *ca // It returns a nil type to indicate that the query should not proceed. // // (It is factored out to allow it to be used both in the query package -// then (in [localImplementations]) again in the declarating package.) +// then (in [localImplementations]) again in the declaring package.) func typeOrMethod(obj types.Object) (types.Type, *types.Func) { switch obj := obj.(type) { case *types.TypeName: diff --git a/gopls/internal/golang/rename_check.go b/gopls/internal/golang/rename_check.go index 060a2f5e6c6..bbac4558bec 100644 --- a/gopls/internal/golang/rename_check.go +++ b/gopls/internal/golang/rename_check.go @@ -331,7 +331,7 @@ func deeper(x, y *types.Scope) bool { // scope that begins at the end of its ValueSpec, or after the // AssignStmt for a var declared by ":=". // -// - Each type {t,u} in the body has a scope that that begins at +// - Each type {t,u} in the body has a scope that begins at // the start of the TypeSpec, so they can be self-recursive // but--unlike package-level types--not mutually recursive. diff --git a/gopls/internal/golang/type_hierarchy.go b/gopls/internal/golang/type_hierarchy.go index bbcd5325d7b..71aec6a8365 100644 --- a/gopls/internal/golang/type_hierarchy.go +++ b/gopls/internal/golang/type_hierarchy.go @@ -91,7 +91,7 @@ func Subtypes(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, ite return relatedTypes(ctx, snapshot, fh, item, methodsets.Subtype) } -// Subtypes reports information about supertypes of the selected type. +// Supertypes reports information about supertypes of the selected type. 
func Supertypes(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, item protocol.TypeHierarchyItem) ([]protocol.TypeHierarchyItem, error) { return relatedTypes(ctx, snapshot, fh, item, methodsets.Supertype) } diff --git a/gopls/internal/protocol/command/interface.go b/gopls/internal/protocol/command/interface.go index 01d41dec473..46abd9184d9 100644 --- a/gopls/internal/protocol/command/interface.go +++ b/gopls/internal/protocol/command/interface.go @@ -273,7 +273,7 @@ type Interface interface { // ClientOpenURL: Request that the client open a URL in a browser. ClientOpenURL(_ context.Context, url string) error - // ScanImports: force a sychronous scan of the imports cache. + // ScanImports: force a synchronous scan of the imports cache. // // This command is intended for use by gopls tests only. ScanImports(context.Context) error diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index b16009ec0ce..8782dfd1460 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -1551,7 +1551,7 @@ func openClientBrowser(ctx context.Context, cli protocol.Client, title string, u Message: fmt.Sprintf("%s: open your browser to %s", title, url), } if err := cli.ShowMessage(ctx, params); err != nil { - event.Error(ctx, "failed to show brower url", err) + event.Error(ctx, "failed to show browser url", err) } } } diff --git a/gopls/internal/server/prompt.go b/gopls/internal/server/prompt.go index f8895358942..b0392b230d8 100644 --- a/gopls/internal/server/prompt.go +++ b/gopls/internal/server/prompt.go @@ -51,7 +51,7 @@ const ( // The following environment variables may be set by the client. // Exported for testing telemetry integration. 
const ( - GoTelemetryGoplsClientStartTimeEnvvar = "GOTELEMETRY_GOPLS_CLIENT_START_TIME" // telemetry start time recored in client + GoTelemetryGoplsClientStartTimeEnvvar = "GOTELEMETRY_GOPLS_CLIENT_START_TIME" // telemetry start time recorded in client GoTelemetryGoplsClientTokenEnvvar = "GOTELEMETRY_GOPLS_CLIENT_TOKEN" // sampling token ) diff --git a/gopls/internal/settings/analysis.go b/gopls/internal/settings/analysis.go index 48a783ac486..59b88ba840f 100644 --- a/gopls/internal/settings/analysis.go +++ b/gopls/internal/settings/analysis.go @@ -133,7 +133,7 @@ func (a *Analyzer) ActionKinds() []protocol.CodeActionKind { return a.actionKind // less intrusive than Info diagnostics. The rule of thumb is this: use Info if // the diagnostic is not a bug, but the author probably didn't mean to write // the code that way. Use Hint if the diagnostic is not a bug and the author -// indended to write the code that way, but there is a simpler or more modern +// intended to write the code that way, but there is a simpler or more modern // way to express the same logic. An 'unused' diagnostic is Info level, since // the author probably didn't mean to check in unreachable code. A 'modernize' // or 'deprecated' diagnostic is Hint level, since the author intended to write diff --git a/gopls/internal/test/integration/bench/bench_test.go b/gopls/internal/test/integration/bench/bench_test.go index d7c1fd976bd..9858177b7e0 100644 --- a/gopls/internal/test/integration/bench/bench_test.go +++ b/gopls/internal/test/integration/bench/bench_test.go @@ -236,7 +236,7 @@ func (s *SidecarServer) Connect(ctx context.Context) jsonrpc2.Conn { // Note: don't use CommandContext here, as we want gopls to exit gracefully // in order to write out profile data. // - // We close the connection on context cancelation below. + // We close the connection on context cancellation below. cmd := exec.Command(s.goplsPath, s.args...) 
stdin, err := cmd.StdinPipe() diff --git a/gopls/internal/test/integration/bench/imports_test.go b/gopls/internal/test/integration/bench/imports_test.go index 3f47a561681..fae217eee6d 100644 --- a/gopls/internal/test/integration/bench/imports_test.go +++ b/gopls/internal/test/integration/bench/imports_test.go @@ -46,7 +46,7 @@ func BenchmarkInitialGoimportsScan(b *testing.B) { defer env.Close() env.Await(InitialWorkspaceLoad) - // Create a buffer with a dangling selctor where the receiver is a single + // Create a buffer with a dangling selector where the receiver is a single // character ('a') that matches a large fraction of the module cache. env.CreateBuffer("internal/lsp/cache/temp.go", ` // This is a temp file to exercise goimports scan of the module cache. diff --git a/gopls/internal/test/integration/misc/hover_test.go b/gopls/internal/test/integration/misc/hover_test.go index b6b7b679357..b79057def33 100644 --- a/gopls/internal/test/integration/misc/hover_test.go +++ b/gopls/internal/test/integration/misc/hover_test.go @@ -332,7 +332,7 @@ func Hello() string { t.Errorf("item:%q not sig:%q", itemContent, sigContent) } if !strings.Contains(hoverContent, itemContent) { - t.Errorf("hover:%q does not containt sig;%q", hoverContent, sigContent) + t.Errorf("hover:%q does not contain sig:%q", hoverContent, sigContent) } }) } diff --git a/gopls/internal/test/integration/web/pkdoc_test.go b/gopls/internal/test/integration/web/pkdoc_test.go index 7f940e9ddd1..b5001421f8e 100644 --- a/gopls/internal/test/integration/web/pkdoc_test.go +++ b/gopls/internal/test/integration/web/pkdoc_test.go @@ -71,7 +71,7 @@ func (G[T]) F(int, int, int, int, int, int, int, ...int) {} collectDocs := env.Awaiter.ListenToShownDocuments() get(t, srcURL) - // Check that that shown location is that of NewFunc. + // Check that the shown location is that of NewFunc.
shownSource := shownDocument(t, collectDocs(), "file:") gotLoc := protocol.Location{ URI: protocol.DocumentURI(shownSource.URI), // fishy conversion diff --git a/gopls/internal/test/marker/testdata/quickfix/stub.txt b/gopls/internal/test/marker/testdata/quickfix/stub.txt index 385565e3eaf..45f8918fe29 100644 --- a/gopls/internal/test/marker/testdata/quickfix/stub.txt +++ b/gopls/internal/test/marker/testdata/quickfix/stub.txt @@ -194,7 +194,7 @@ package stub import "io" -// This file tests that that the stub method generator accounts for concrete +// This file tests that the stub method generator accounts for concrete // types that have type parameters defined. var _ io.ReaderFrom = &genReader[string, int]{} //@quickfix("&genReader", re"does not implement", generic_receiver) diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index f54c3f4208d..e8292d007d5 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -619,8 +619,8 @@ Outer: // otherwise remove the line edit := analysis.TextEdit{Pos: stmt.Pos(), End: stmt.End()} if from.IsValid() || to.IsValid() { - // remove just the statment. - // we can't tell if there is a ; or whitespace right after the statment + // remove just the statement. + // we can't tell if there is a ; or whitespace right after the statement // ideally we'd like to remove the former and leave the latter // (if gofmt has run, there likely won't be a ;) // In type switches we know there's a semicolon somewhere after the statement, diff --git a/internal/diff/lcs/common.go b/internal/diff/lcs/common.go index c3e82dd2683..27fa9ecbd5c 100644 --- a/internal/diff/lcs/common.go +++ b/internal/diff/lcs/common.go @@ -51,7 +51,7 @@ func (l lcs) fix() lcs { // from the set of diagonals in l, find a maximal non-conflicting set // this problem may be NP-complete, but we use a greedy heuristic, // which is quadratic, but with a better data structure, could be D log D. 
- // indepedent is not enough: {0,3,1} and {3,0,2} can't both occur in an lcs + // independent is not enough: {0,3,1} and {3,0,2} can't both occur in an lcs // which has to have monotone x and y if len(l) == 0 { return nil diff --git a/internal/imports/fix.go b/internal/imports/fix.go index d2e275934e4..50b6ca51a6b 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -291,7 +291,7 @@ func (p *pass) loadPackageNames(ctx context.Context, imports []*ImportInfo) erro return nil } -// WithouVersion removes a trailing major version, if there is one. +// WithoutVersion removes a trailing major version, if there is one. func WithoutVersion(nm string) string { if v := path.Base(nm); len(v) > 0 && v[0] == 'v' { if _, err := strconv.Atoi(v[1:]); err == nil { diff --git a/internal/imports/imports.go b/internal/imports/imports.go index 2215a12880a..b5f5218b5cc 100644 --- a/internal/imports/imports.go +++ b/internal/imports/imports.go @@ -93,7 +93,7 @@ func FixImports(ctx context.Context, filename string, src []byte, goroot string, // env is needed. func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) { // Don't use parse() -- we don't care about fragments or statement lists - // here, and we need to work with unparseable files. + // here, and we need to work with unparsable files. fileSet := token.NewFileSet() parserMode := parser.SkipObjectResolution if opt.Comments { diff --git a/internal/mcp/README.md b/internal/mcp/README.md index c1af3729182..a4fc3dee443 100644 --- a/internal/mcp/README.md +++ b/internal/mcp/README.md @@ -88,7 +88,7 @@ func main() { // Create a server with a single tool. 
server := mcp.NewServer("greeter", "v1.0.0", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) - // Run the server over stdin/stdout, until the client diconnects + // Run the server over stdin/stdout, until the client disconnects _ = server.Run(context.Background(), mcp.NewStdIOTransport()) } ``` diff --git a/internal/mcp/content.go b/internal/mcp/content.go index 94f5cd18f4a..7e0b89b4be8 100644 --- a/internal/mcp/content.go +++ b/internal/mcp/content.go @@ -114,7 +114,7 @@ func NewTextResourceContents(uri, mimeType, text string) *ResourceContents { } } -// NewTextResourceContents returns a [ResourceContents] containing a byte slice. +// NewBlobResourceContents returns a [ResourceContents] containing a byte slice. func NewBlobResourceContents(uri, mimeType string, blob []byte) *ResourceContents { // The only way to distinguish text from blob is a non-nil Blob field. if blob == nil { diff --git a/internal/mcp/internal/readme/server/server.go b/internal/mcp/internal/readme/server/server.go index 185a7297d4d..867d4c1e08d 100644 --- a/internal/mcp/internal/readme/server/server.go +++ b/internal/mcp/internal/readme/server/server.go @@ -25,7 +25,7 @@ func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) server.AddTools(mcp.NewTool("greet", "say hi", SayHi)) - // Run the server over stdin/stdout, until the client diconnects + // Run the server over stdin/stdout, until the client disconnects _ = server.Run(context.Background(), mcp.NewStdIOTransport()) } diff --git a/internal/mcp/internal/util/util.go b/internal/mcp/internal/util/util.go index cdc6038ede8..8ef5d1fd464 100644 --- a/internal/mcp/internal/util/util.go +++ b/internal/mcp/internal/util/util.go @@ -12,7 +12,7 @@ import ( // Helpers below are copied from gopls' moremaps package. -// sorted returns an iterator over the entries of m in key order. +// Sorted returns an iterator over the entries of m in key order. 
func Sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] { // TODO(adonovan): use maps.Sorted if proposal #68598 is accepted. return func(yield func(K, V) bool) { @@ -26,7 +26,7 @@ func Sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] { } } -// keySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)). +// KeySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)). func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { r := make([]K, 0, len(m)) for k := range m { diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 717bdbb0c04..2b60759cf7d 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -132,7 +132,7 @@ func TestResolveURIs(t *testing.T) { func TestRefCycle(t *testing.T) { // Verify that cycles of refs are OK. - // The test suite doesn't check this, suprisingly. + // The test suite doesn't check this, surprisingly. schemas := map[string]*Schema{ "root": {Ref: "a"}, "a": {Ref: "b"}, diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 104fa645089..8ad99b80457 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -43,7 +43,7 @@ type Server struct { type ServerOptions struct { // Optional instructions for connected clients. Instructions string - // If non-nil, called when "notifications/intialized" is received. + // If non-nil, called when "notifications/initialized" is received. InitializedHandler func(context.Context, *ServerSession, *InitializedParams) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). @@ -122,7 +122,7 @@ func (s *Server) RemoveTools(names ...string) { func() bool { return s.tools.remove(names...) 
}) } -// AddResource adds the given resource to the server and associates it with +// AddResources adds the given resource to the server and associates it with // a [ResourceHandler], which will be called when the client calls [ClientSession.ReadResource]. // If a resource with the same URI already exists, this one replaces it. // AddResource panics if a resource URI is invalid or not absolute (has an empty scheme). diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index 0e97391dd48..bb7957e05d8 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -144,7 +144,7 @@ const ( // See https://modelcontextprotocol.io/specification/2025-03-26/server/resources#error-handling // However, the code they chose is in the wrong space // (see https://github.com/modelcontextprotocol/modelcontextprotocol/issues/509). - // so we pick a different one, arbirarily for now (until they fix it). + // so we pick a different one, arbitrarily for now (until they fix it). // The immediate problem is that jsonprc2 defines -32002 as "server closing". CodeResourceNotFound = -31002 // The error code if the method exists and was called properly, but the peer does not support it. diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index bd82538769a..9b7ebf590fa 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -102,7 +102,7 @@ type SSEServerTransport struct { incoming chan jsonrpc2.Message // queue of incoming messages; never closed // We must guard both pushes to the incoming queue and writes to the response - // writer, because incoming POST requests are abitrarily concurrent and we + // writer, because incoming POST requests are arbitrarily concurrent and we // need to ensure we don't write push to the queue, or write to the // ResponseWriter, after the session GET request exits. 
mu sync.Mutex @@ -431,7 +431,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) { // An sseClientStream is a logical jsonrpc2 stream that implements the client // half of the SSE protocol: -// - Writes are POSTS to the sesion endpoint. +// - Writes are POSTS to the session endpoint. // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientStream struct { diff --git a/internal/modindex/directories.go b/internal/modindex/directories.go index 1e1a02f239b..2faa6ce0b8a 100644 --- a/internal/modindex/directories.go +++ b/internal/modindex/directories.go @@ -26,7 +26,7 @@ type directory struct { syms []symbol } -// filterDirs groups the directories by import path, +// byImportPath groups the directories by import path, // sorting the ones with the same import path by semantic version, // most recent first. func byImportPath(dirs []Relpath) (map[string][]*directory, error) { diff --git a/internal/modindex/gomodindex/cmd.go b/internal/modindex/gomodindex/cmd.go index 4fc0caf400e..fd281cb0a56 100644 --- a/internal/modindex/gomodindex/cmd.go +++ b/internal/modindex/gomodindex/cmd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A command for building and maintaing the module cache +// A command for building and maintaining the module cache // a.out // The commands are 'create' which builds a new index, // 'update', which attempts to update an existing index, @@ -35,7 +35,7 @@ type cmd struct { var cmds = []cmd{ {"create", index, "create a clean index of GOMODCACHE"}, - {"update", update, "if there is an existing index of GOMODCACHE, update it. Otherise create one."}, + {"update", update, "if there is an existing index of GOMODCACHE, update it. 
Otherwise create one."}, {"clean", clean, "removed unreferenced indexes more than an hour old"}, {"query", query, "not yet implemented"}, } diff --git a/internal/modindex/lookup.go b/internal/modindex/lookup.go index bd605e0d763..34e3673ff54 100644 --- a/internal/modindex/lookup.go +++ b/internal/modindex/lookup.go @@ -75,7 +75,7 @@ func (ix *Index) Lookup(pkg, name string, prefix bool) []Candidate { return nil // didn't find the package } var ans []Candidate - // loc is the first entry for this package name, but there may be severeal + // loc is the first entry for this package name, but there may be several for i := loc; i < len(ix.Entries); i++ { e := ix.Entries[i] if e.PkgName != pkg { diff --git a/internal/modindex/symbols.go b/internal/modindex/symbols.go index b918529d43e..31a502c5891 100644 --- a/internal/modindex/symbols.go +++ b/internal/modindex/symbols.go @@ -30,7 +30,7 @@ import ( type symbol struct { pkg string // name of the symbols's package name string // declared name - kind string // T, C, V, or F, follwed by D if deprecated + kind string // T, C, V, or F, followed by D if deprecated sig string // signature information, for F } @@ -110,7 +110,7 @@ func getFileExports(f *ast.File) []symbol { // The only place a $ can occur seems to be in a struct tag, which // can be an arbitrary string literal, and ExprString does not presently // print struct tags. So for this to happen the type of a formal parameter - // has to be a explict struct, e.g. foo(x struct{a int "$"}) and ExprString + // has to be a explicit struct, e.g. foo(x struct{a int "$"}) and ExprString // would have to show the struct tag. 
Even testing for this case seems // a waste of effort, but let's remember the possibility if strings.Contains(tp, "$") { diff --git a/internal/refactor/inline/callee.go b/internal/refactor/inline/callee.go index d4f53310a2a..41deebb8228 100644 --- a/internal/refactor/inline/callee.go +++ b/internal/refactor/inline/callee.go @@ -603,7 +603,7 @@ func analyzeAssignment(info *types.Info, stack []ast.Node) (assignable, ifaceAss } } - // Types do not need to match for index expresions. + // Types do not need to match for index expressions. if ix, ok := parent.(*ast.IndexExpr); ok { if ix.Index == expr { typ := info.TypeOf(ix.X) diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index 445f6b705c4..7e1c9994ddf 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -1741,7 +1741,7 @@ next: } } - // Arg is a potential substition candidate: analyze its shadowing. + // Arg is a potential substitution candidate: analyze its shadowing. // // Consider inlining a call f(z, 1) to // diff --git a/internal/refactor/inline/testdata/method.txtar b/internal/refactor/inline/testdata/method.txtar index 92343edd840..05767e25710 100644 --- a/internal/refactor/inline/testdata/method.txtar +++ b/internal/refactor/inline/testdata/method.txtar @@ -1,7 +1,7 @@ Test of inlining a method call. The call to (*T).g0 implicitly takes the address &x, and -the call to T.h implictly dereferences the argument *ptr. +the call to T.h implicitly dereferences the argument *ptr. The f1/g1 methods have parameters, exercising the splicing of the receiver into the parameter list. 
diff --git a/internal/typesinternal/zerovalue_test.go b/internal/typesinternal/zerovalue_test.go index 67295a95020..ca1e28b3d91 100644 --- a/internal/typesinternal/zerovalue_test.go +++ b/internal/typesinternal/zerovalue_test.go @@ -25,7 +25,7 @@ func TestZeroValue(t *testing.T) { testenv.NeedsGoExperiment(t, "aliastypeparams") } - // This test only refernece types/functions defined within the same package. + // This test only reference types/functions defined within the same package. // We can safely drop the package name when encountered. qual := types.Qualifier(func(p *types.Package) string { return "" From 23911234aa4d83bebe44c5d04f68cd877b8b42d0 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 23 May 2025 15:41:07 -0400 Subject: [PATCH 124/196] internal/mcp: treat zero struct fields generously Validating a struct is ambiguous, because a zero field could be considered missing or present. Interpret a zero optional struct field in whichever way results in success. Change-Id: I6a9474e6f6558d2a8522bc5ba0451967367e467d Reviewed-on: https://go-review.googlesource.com/c/tools/+/675956 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/resolve.go | 9 +++ internal/mcp/jsonschema/schema.go | 3 + internal/mcp/jsonschema/validate.go | 38 +++++++-- internal/mcp/jsonschema/validate_test.go | 98 +++++++++++++++++++----- 4 files changed, 122 insertions(+), 26 deletions(-) diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index d28fba42cd9..fa3e3d2ad50 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -181,6 +181,15 @@ func (s *Schema) checkLocal(report func(error)) { s.patternProperties[re] = subschema } } + + // Build a set of required properties, to avoid quadratic behavior when validating + // a struct. 
+ if len(s.Required) > 0 { + s.isRequired = map[string]bool{} + for _, r := range s.Required { + s.isRequired[r] = true + } + } } // resolveURIs resolves the ids and anchors in all the schemas of root, relative diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 3fbf861af17..0cf9d4d4b7a 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -155,6 +155,9 @@ type Schema struct { // compiled regexps pattern *regexp.Regexp patternProperties map[*regexp.Regexp]*Schema + + // the set of required properties + isRequired map[string]bool } // falseSchema returns a new Schema tree that fails to validate any value. diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index bc1701428c6..6651ec10ea2 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -271,6 +271,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } // arrays + // TODO(jba): consider arrays of structs. if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 // This validate call doesn't collect annotations for the items of the instance; they are separate @@ -386,13 +387,19 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // If we used anns here, then we'd be including properties evaluated in subschemas // from allOf, etc., which additionalProperties shouldn't observe. evalProps := map[string]bool{} - for prop, schema := range schema.Properties { + for prop, subschema := range schema.Properties { val := property(instance, prop) if !val.IsValid() { // It's OK if the instance doesn't have the property. 
continue } - if err := st.validate(val, schema, nil, append(path, prop)); err != nil { + // If the instance is a struct and an optional property has the zero + // value, then we could interpret it as present or missing. Be generous: + // assume it's missing, and thus always validates successfully. + if instance.Kind() == reflect.Struct && val.IsZero() && !schema.isRequired[prop] { + continue + } + if err := st.validate(val, subschema, nil, append(path, prop)); err != nil { return err } evalProps[prop] = true @@ -433,13 +440,17 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 + var min, max int + if schema.MinProperties != nil || schema.MaxProperties != nil { + min, max = numPropertiesBounds(instance, schema.isRequired) + } if schema.MinProperties != nil { - if n, m := numProperties(instance), *schema.MinProperties; n < m { + if n, m := max, *schema.MinProperties; n < m { return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) } } if schema.MaxProperties != nil { - if n, m := numProperties(instance), *schema.MaxProperties; n > m { + if n, m := min, *schema.MaxProperties; n > m { return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) } } @@ -557,14 +568,25 @@ func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { } } -// numProperties returns the number of v's properties. +// numPropertiesBounds returns bounds on the number of v's properties. // v must be a map or a struct. -func numProperties(v reflect.Value) int { +// If v is a map, both bounds are the map's size. +// If v is a struct, the max is the number of struct properties. +// But since we don't know whether a zero value indicates a missing optional property +// or not, be generous and use the number of non-zero properties as the min. 
+func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { switch v.Kind() { case reflect.Map: - return v.Len() + return v.Len(), v.Len() case reflect.Struct: - return len(structPropertiesOf(v.Type())) + sp := structPropertiesOf(v.Type()) + min := 0 + for prop, index := range sp { + if !v.FieldByIndex(index).IsZero() || isRequired[prop] { + min++ + } + } + return min, len(sp) default: panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) } diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index d6be6f8cfc4..b6b8d61a191 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -81,30 +81,92 @@ func TestStructInstance(t *testing.T) { instance := struct { I int B bool `json:"b"` - u int - }{1, true, 0} + P *int // either missing or nil + u int // unexported: not a property + }{1, true, nil, 0} - // The instance fails for all of these schemas, demonstrating that it - // was processed correctly. 
- for _, schema := range []*Schema{ - {MinProperties: Ptr(3)}, - {MaxProperties: Ptr(1)}, - {Required: []string{"i"}}, // the name is "I" - {Required: []string{"B"}}, // the name is "b" - {PropertyNames: &Schema{MinLength: Ptr(2)}}, - {Properties: map[string]*Schema{"b": {Type: "number"}}}, - {Required: []string{"I"}, AdditionalProperties: falseSchema()}, - {DependentRequired: map[string][]string{"b": {"u"}}}, - {DependentSchemas: map[string]*Schema{"b": falseSchema()}}, - {UnevaluatedProperties: falseSchema()}, + for _, tt := range []struct { + s Schema + want bool + }{ + { + Schema{MinProperties: Ptr(4)}, + false, + }, + { + Schema{MinProperties: Ptr(3)}, + true, // P interpreted as present + }, + { + Schema{MaxProperties: Ptr(1)}, + false, + }, + { + Schema{MaxProperties: Ptr(2)}, + true, // P interpreted as absent + }, + { + Schema{Required: []string{"i"}}, // the name is "I" + false, + }, + { + Schema{Required: []string{"B"}}, // the name is "b" + false, + }, + { + Schema{PropertyNames: &Schema{MinLength: Ptr(2)}}, + false, + }, + { + Schema{Properties: map[string]*Schema{"b": {Type: "boolean"}}}, + true, + }, + { + Schema{Properties: map[string]*Schema{"b": {Type: "number"}}}, + false, + }, + { + Schema{Required: []string{"I"}}, + true, + }, + { + Schema{Required: []string{"I", "P"}}, + true, // P interpreted as present + }, + { + Schema{Required: []string{"I", "P"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, + false, // P interpreted as present, but not a number + }, + { + Schema{Required: []string{"I"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, + true, // P not required, so interpreted as absent + }, + { + Schema{Required: []string{"I"}, AdditionalProperties: falseSchema()}, + false, + }, + { + Schema{DependentRequired: map[string][]string{"b": {"u"}}}, + false, + }, + { + Schema{DependentSchemas: map[string]*Schema{"b": falseSchema()}}, + false, + }, + { + Schema{UnevaluatedProperties: falseSchema()}, + false, + }, } { - res, err 
:= schema.Resolve("", nil) + res, err := tt.s.Resolve("", nil) if err != nil { t.Fatal(err) } err = res.Validate(instance) - if err == nil { - t.Errorf("succeeded but wanted failure; schema = %s", schema.json()) + if err == nil && !tt.want { + t.Errorf("succeeded unexpectedly\nschema = %s", tt.s.json()) + } else if err != nil && tt.want { + t.Errorf("Validate: %v\nschema = %s", err, tt.s.json()) } } } From 1c017f1652c52aa388b9989fbe6a28e5a3bd9f3d Mon Sep 17 00:00:00 2001 From: aarzilli Date: Sat, 24 May 2025 13:01:31 +0200 Subject: [PATCH 125/196] internal/telemetry/cmd/stacks: support Delve Changes to support Delve. Change-Id: Iee9d79704ced6252745335a0a4d14843996d6ce6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676015 Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Reviewed-by: Dmitri Shuralyov --- gopls/internal/telemetry/cmd/stacks/stacks.go | 64 +++++++++++++++---- .../telemetry/cmd/stacks/stacks_test.go | 4 +- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/gopls/internal/telemetry/cmd/stacks/stacks.go b/gopls/internal/telemetry/cmd/stacks/stacks.go index cb0a21b4ec2..17f180cacd0 100644 --- a/gopls/internal/telemetry/cmd/stacks/stacks.go +++ b/gopls/internal/telemetry/cmd/stacks/stacks.go @@ -133,6 +133,9 @@ type ProgramConfig struct { // IgnoreSymbolContains are "uninteresting" symbol substrings. e.g., // logging packages. 
IgnoreSymbolContains []string + + // Repository is the repository where the issues should be created, for example: "golang/go" + Repository string } var programs = map[string]ProgramConfig{ @@ -152,6 +155,7 @@ var programs = map[string]ProgramConfig{ "internal/util/bug.", "internal/bug.", // former name in gopls/0.14.2 }, + Repository: "golang/go", }, "cmd/compile": { Program: "cmd/compile", @@ -173,6 +177,23 @@ var programs = map[string]ProgramConfig{ "cmd/compile/internal/types2.(*Checker).handleBailout", "cmd/compile/internal/gc.handlePanic", }, + Repository: "golang/go", + }, + "github.com/go-delve/delve/cmd/dlv": { + Program: "github.com/go-delve/delve/cmd/dlv", + IncludeClient: false, + SearchLabel: "delve/telemetry-wins", + NewIssuePrefix: "telemetry report", + NewIssueLabels: []string{ + "delve/telemetry-wins", + }, + MatchSymbolPrefix: "github.com/go-delve/delve", + IgnoreSymbolContains: []string{ + "service/dap.(*Session).recoverPanic", + "rpccommon.newInternalError", + "rpccommon.(*ServerImpl).serveJSONCodec", + }, + Repository: "go-delve/delve", }, } @@ -227,7 +248,7 @@ func main() { claimedBy := claimStacks(issues, stacks) // Update existing issues that claimed new stacks. - updateIssues(ghclient, issues, stacks, stackToURL) + updateIssues(ghclient, pcfg.Repository, issues, stacks, stackToURL) // For each stack, show existing issue or create a new one. // Aggregate stack IDs by issue summary. @@ -402,7 +423,7 @@ func readReports(pcfg ProgramConfig, days int) (stacks map[string]map[Info]int64 // predicates. func readIssues(cli *githubClient, pcfg ProgramConfig) ([]*Issue, error) { // Query GitHub for all existing GitHub issues with the report label. - issues, err := cli.searchIssues(pcfg.SearchLabel) + issues, err := cli.searchIssues(pcfg.Repository, pcfg.SearchLabel) if err != nil { // TODO(jba): return error instead of dying, or doc. 
log.Fatalf("GitHub issues label %q search failed: %v", pcfg.SearchLabel, err) @@ -581,7 +602,7 @@ func claimStacks(issues []*Issue, stacks map[string]map[Info]int64) map[string]* } // updateIssues updates existing issues that claimed new stacks by predicate. -func updateIssues(cli *githubClient, issues []*Issue, stacks map[string]map[Info]int64, stackToURL map[string]string) { +func updateIssues(cli *githubClient, repo string, issues []*Issue, stacks map[string]map[Info]int64, stackToURL map[string]string) { for _, issue := range issues { if len(issue.newStacks) == 0 { continue @@ -597,7 +618,7 @@ func updateIssues(cli *githubClient, issues []*Issue, stacks map[string]map[Info writeStackComment(comment, stack, id, stackToURL[stack], stacks[stack]) } - if err := cli.addIssueComment(issue.Number, comment.String()); err != nil { + if err := cli.addIssueComment(repo, issue.Number, comment.String()); err != nil { log.Println(err) continue } @@ -616,7 +637,7 @@ func updateIssues(cli *githubClient, issues []*Issue, stacks map[string]map[Info update.State = "open" update.StateReason = "reopened" } - if err := cli.updateIssue(update); err != nil { + if err := cli.updateIssue(repo, update); err != nil { log.Printf("added comment to issue #%d but failed to update: %v", issue.Number, err) continue @@ -783,7 +804,7 @@ outer: // Report it. The user will interactively finish the task, // since they will typically de-dup it without even creating a new issue // by expanding the #!stacks predicate of an existing issue. 
- if !browser.Open("https://github.com/golang/go/issues/new?labels=" + labels + "&title=" + url.QueryEscape(title) + "&body=" + url.QueryEscape(body.String())) { + if !browser.Open("https://github.com/" + pcfg.Repository + "/issues/new?labels=" + labels + "&title=" + url.QueryEscape(title) + "&body=" + url.QueryEscape(body.String())) { log.Print("Please file a new issue at golang.org/issue/new using this template:\n\n") log.Printf("Title: %s\n", title) log.Printf("Labels: %s\n", labels) @@ -910,6 +931,14 @@ func frameURL(pclntab map[string]FileLine, info Info, frame string) string { } } + // Delve + const delveRepo = "github.com/go-delve/delve/" + if strings.HasPrefix(fileline.file, delveRepo) { + filename := fileline.file[len(delveRepo):] + return fmt.Sprintf("https://%sblob/%s/%s#L%d", delveRepo, info.ProgramVersion, filename, linenum) + + } + log.Printf("no CodeSearch URL for %q (%s:%d)", symbol, fileline.file, linenum) return "" @@ -952,7 +981,7 @@ type updateIssue struct { // -- GitHub search -- // searchIssues queries the GitHub issue tracker. -func (cli *githubClient) searchIssues(label string) ([]*Issue, error) { +func (cli *githubClient) searchIssues(repo, label string) ([]*Issue, error) { label = url.QueryEscape(label) // Slurp all issues with the telemetry label. @@ -966,7 +995,7 @@ func (cli *githubClient) searchIssues(label string) ([]*Issue, error) { // issues across pages. getPage := func(page int) ([]*Issue, error) { - url := fmt.Sprintf("https://api.github.com/repos/golang/go/issues?state=all&labels=%s&per_page=100&page=%d", label, page) + url := fmt.Sprintf("https://api.github.com/repos/%s/issues?state=all&labels=%s&per_page=100&page=%d", repo, label, page) req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -1007,7 +1036,7 @@ func (cli *githubClient) searchIssues(label string) ([]*Issue, error) { } // updateIssue updates the numbered issue. 
-func (cli *githubClient) updateIssue(update updateIssue) error { +func (cli *githubClient) updateIssue(repo string, update updateIssue) error { if cli.divertChanges { cli.changes = append(cli.changes, update) return nil @@ -1018,7 +1047,7 @@ func (cli *githubClient) updateIssue(update updateIssue) error { return err } - url := fmt.Sprintf("https://api.github.com/repos/golang/go/issues/%d", update.number) + url := fmt.Sprintf("https://api.github.com/repos/%s/issues/%d", repo, update.number) if err := cli.requestChange("PATCH", url, data, http.StatusOK); err != nil { return fmt.Errorf("updating issue: %v", err) } @@ -1026,7 +1055,7 @@ func (cli *githubClient) updateIssue(update updateIssue) error { } // addIssueComment adds a markdown comment to the numbered issue. -func (cli *githubClient) addIssueComment(number int, comment string) error { +func (cli *githubClient) addIssueComment(repo string, number int, comment string) error { if cli.divertChanges { cli.changes = append(cli.changes, addIssueComment{number, comment}) return nil @@ -1042,7 +1071,7 @@ func (cli *githubClient) addIssueComment(number int, comment string) error { return err } - url := fmt.Sprintf("https://api.github.com/repos/golang/go/issues/%d/comments", number) + url := fmt.Sprintf("https://api.github.com/repos/%s/issues/%d/comments", repo, number) if err := cli.requestChange("POST", url, data, http.StatusCreated); err != nil { return fmt.Errorf("creating issue comment: %v", err) } @@ -1167,6 +1196,17 @@ func readPCLineTable(info Info, stacksDir string) (map[string]FileLine, error) { // directory its go.mod doesn't restrict the toolchain versions // we're allowed to use. 
buildDir = "/" + case "github.com/go-delve/delve/cmd/dlv": + revDir := filepath.Join(stacksDir, "delve@"+info.ProgramVersion) + if !fileExists(filepath.Join(revDir, "go.mod")) { + _ = os.RemoveAll(revDir) + log.Printf("cloning github.com/go-delve/delve@%s", info.ProgramVersion) + if err := shallowClone(revDir, "https://github.com/go-delve/delve", info.ProgramVersion); err != nil { + _ = os.RemoveAll(revDir) + return nil, fmt.Errorf("clone: %v", err) + } + } + buildDir = revDir default: return nil, fmt.Errorf("don't know how to build unknown program %s", info.Program) } diff --git a/gopls/internal/telemetry/cmd/stacks/stacks_test.go b/gopls/internal/telemetry/cmd/stacks/stacks_test.go index 9f798aa43a3..d7bf12f830f 100644 --- a/gopls/internal/telemetry/cmd/stacks/stacks_test.go +++ b/gopls/internal/telemetry/cmd/stacks/stacks_test.go @@ -177,7 +177,7 @@ func TestUpdateIssues(t *testing.T) { ProgramVersion: "v0.16.1", } stacks := map[string]map[Info]int64{stack1: map[Info]int64{info: 3}} - updateIssues(c, issues, stacks, stacksToURL) + updateIssues(c, "golang/go", issues, stacks, stacksToURL) changes := c.takeChanges() if g, w := len(changes), 2; g != w { @@ -218,7 +218,7 @@ func TestUpdateIssues(t *testing.T) { ProgramVersion: "v0.17.0", } stacks := map[string]map[Info]int64{stack1: map[Info]int64{info: 3}} - updateIssues(c, issues, stacks, stacksToURL) + updateIssues(c, "golang/go", issues, stacks, stacksToURL) changes := c.takeChanges() if g, w := len(changes), 2; g != w { From baa4e14bebbb5ed05414a5d94d386419fc12e570 Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Sun, 25 May 2025 09:03:29 -0400 Subject: [PATCH 126/196] internal/modindex: tiny test improvement Simplify some output statements. 
Change-Id: Idd6b9f6d02f7e362653b5b8884b1e5cc8b2157e6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676195 Reviewed-by: Madeline Kalil LUCI-TryBot-Result: Go LUCI --- internal/modindex/lookup_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/modindex/lookup_test.go b/internal/modindex/lookup_test.go index 191395cffc9..265b26767dd 100644 --- a/internal/modindex/lookup_test.go +++ b/internal/modindex/lookup_test.go @@ -147,11 +147,11 @@ func TestLookupAll(t *testing.T) { t.Fatal(err) } defer fd.Close() - if _, err := fd.WriteString(fmt.Sprintf("package foo\n")); err != nil { + if _, err := fmt.Fprintf(fd, "package foo\n"); err != nil { t.Fatal(err) } for _, nm := range nms { - fd.WriteString(fmt.Sprintf("func %s() {}\n", nm)) + fmt.Fprintf(fd, "func %s() {}\n", nm) } } wrtModule("a.com/go/x4@v1.1.1", "A", "B", "C", "D") From 866eb14a2ca83430f9308d53d8bd2b1a1260b116 Mon Sep 17 00:00:00 2001 From: cuishuang Date: Sun, 25 May 2025 00:55:44 +0800 Subject: [PATCH 127/196] go/analysis/passes/printf: fix the issue where %#q/%#x/%#X recursion is not recognized Add recursive call detection for the %#q, %#x, and %#X formats to the current Go vet printf analyzer. 
Fixes golang/go#73825 Change-Id: If4a524436bc19ff4fca337aba7ca98c1c1ba4fa8 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676115 Auto-Submit: Alan Donovan Reviewed-by: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- go/analysis/passes/printf/printf.go | 4 +- go/analysis/passes/printf/testdata/src/a/a.go | 56 +++++++++++++------ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/go/analysis/passes/printf/printf.go b/go/analysis/passes/printf/printf.go index a28ed365d1e..07d4fcf0a80 100644 --- a/go/analysis/passes/printf/printf.go +++ b/go/analysis/passes/printf/printf.go @@ -758,7 +758,9 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs pass.ReportRangef(call, "%s format %s has arg %s of wrong type %s%s", name, operation.Text, analysisinternal.Format(pass.Fset, arg), typeString, details) return false } - if v.typ&argString != 0 && v.verb != 'T' && !strings.Contains(operation.Flags, "#") { + // Detect recursive formatting via value's String/Error methods. + // The '#' flag suppresses the methods, except with %x, %X, and %q. 
+ if v.typ&argString != 0 && v.verb != 'T' && (!strings.Contains(operation.Flags, "#") || strings.ContainsRune("qxX", v.verb)) { if methodName, ok := recursiveStringer(pass, arg); ok { pass.ReportRangef(call, "%s format %s with arg %s causes recursive %s method call", name, operation.Text, analysisinternal.Format(pass.Fset, arg), methodName) return false diff --git a/go/analysis/passes/printf/testdata/src/a/a.go b/go/analysis/passes/printf/testdata/src/a/a.go index da48f98f0a8..4a35773efe4 100644 --- a/go/analysis/passes/printf/testdata/src/a/a.go +++ b/go/analysis/passes/printf/testdata/src/a/a.go @@ -567,10 +567,16 @@ type recursiveStringer int func (s recursiveStringer) String() string { _ = fmt.Sprintf("%d", s) _ = fmt.Sprintf("%#v", s) - _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringer\).String method call` - _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringer\).String method call` - _ = fmt.Sprintf("%T", s) // ok; does not recursively call String - return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringer\).String method` + _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%#x", s) // want `fmt.Sprintf format %#x with arg s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%#x", &s) // want `fmt.Sprintf format %#x with arg &s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%#X", s) // want `fmt.Sprintf format %#X with arg s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%#X", &s) // want `fmt.Sprintf format %#X with arg &s causes recursive \(a.recursiveStringer\).String method call` + _ = 
fmt.Sprintf("%#q", s) // want `fmt.Sprintf format %#q with arg s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%#q", &s) // want `fmt.Sprintf format %#q with arg &s causes recursive \(a.recursiveStringer\).String method call` + _ = fmt.Sprintf("%T", s) // ok; does not recursively call String + return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringer\).String method` } type recursivePtrStringer int @@ -586,10 +592,16 @@ type recursiveError int func (s recursiveError) Error() string { _ = fmt.Sprintf("%d", s) _ = fmt.Sprintf("%#v", s) - _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveError\).Error method call` - _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveError\).Error method call` - _ = fmt.Sprintf("%T", s) // ok; does not recursively call Error - return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveError\).Error method` + _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#x", s) // want `fmt.Sprintf format %#x with arg s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#x", &s) // want `fmt.Sprintf format %#x with arg &s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#X", s) // want `fmt.Sprintf format %#X with arg s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#X", &s) // want `fmt.Sprintf format %#X with arg &s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#q", s) // want `fmt.Sprintf format %#q with arg s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%#q", &s) // want `fmt.Sprintf 
format %#q with arg &s causes recursive \(a.recursiveError\).Error method call` + _ = fmt.Sprintf("%T", s) // ok; does not recursively call Error + return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveError\).Error method` } type recursivePtrError int @@ -605,19 +617,31 @@ type recursiveStringerAndError int func (s recursiveStringerAndError) String() string { _ = fmt.Sprintf("%d", s) _ = fmt.Sprintf("%#v", s) - _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringerAndError\).String method call` - _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringerAndError\).String method call` - _ = fmt.Sprintf("%T", s) // ok; does not recursively call String - return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringerAndError\).String method` + _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#x", s) // want `fmt.Sprintf format %#x with arg s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#x", &s) // want `fmt.Sprintf format %#x with arg &s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#X", s) // want `fmt.Sprintf format %#X with arg s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#X", &s) // want `fmt.Sprintf format %#X with arg &s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#q", s) // want `fmt.Sprintf format %#q with arg s causes recursive \(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%#q", &s) // want `fmt.Sprintf format %#q with arg &s causes recursive 
\(a.recursiveStringerAndError\).String method call` + _ = fmt.Sprintf("%T", s) // ok; does not recursively call String + return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringerAndError\).String method` } func (s recursiveStringerAndError) Error() string { _ = fmt.Sprintf("%d", s) _ = fmt.Sprintf("%#v", s) - _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringerAndError\).Error method call` - _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringerAndError\).Error method call` - _ = fmt.Sprintf("%T", s) // ok; does not recursively call Error - return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringerAndError\).Error method` + _ = fmt.Sprintf("%v", s) // want `fmt.Sprintf format %v with arg s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%v", &s) // want `fmt.Sprintf format %v with arg &s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#x", s) // want `fmt.Sprintf format %#x with arg s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#x", &s) // want `fmt.Sprintf format %#x with arg &s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#X", s) // want `fmt.Sprintf format %#X with arg s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#X", &s) // want `fmt.Sprintf format %#X with arg &s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#q", s) // want `fmt.Sprintf format %#q with arg s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%#q", &s) // want `fmt.Sprintf format %#q with arg &s causes recursive \(a.recursiveStringerAndError\).Error method call` + _ = fmt.Sprintf("%T", s) // ok; does not recursively call Error + 
return fmt.Sprintln(s) // want `fmt.Sprintln arg s causes recursive call to \(a.recursiveStringerAndError\).Error method` } type recursivePtrStringerAndError int From 6d1bf3b345a09a8a3ad3341273e748ab0e6e4276 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 23 May 2025 18:01:25 -0400 Subject: [PATCH 128/196] internal/test/integration/bench: add second kubernetes benchmark This CL adds a second kubernetes benchmark that touches types.go, a file with many upward dependencies, exercising an extreme case of DidChange. getRepo supports an optional suffix so that many tests may use the same repo. All existing tests continue to use the same name, which is important for continuity of dashboard data. Change-Id: I2bd889ca12184de5f25d8b6433acd626fb50d638 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675959 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/internal/test/integration/bench/didchange_test.go | 3 ++- gopls/internal/test/integration/bench/repo_test.go | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/gopls/internal/test/integration/bench/didchange_test.go b/gopls/internal/test/integration/bench/didchange_test.go index aa87a4f9b0e..5dc7c657e22 100644 --- a/gopls/internal/test/integration/bench/didchange_test.go +++ b/gopls/internal/test/integration/bench/didchange_test.go @@ -21,7 +21,7 @@ import ( var editID int64 = time.Now().UnixNano() type changeTest struct { - repo string + repo string // repo identifier + optional disambiguating ".foo" suffix file string canSave bool } @@ -30,6 +30,7 @@ var didChangeTests = []changeTest{ {"google-cloud-go", "internal/annotate.go", true}, {"istio", "pkg/fuzz/util.go", true}, {"kubernetes", "pkg/controller/lookup_cache.go", true}, + {"kubernetes.types", "staging/src/k8s.io/api/core/v1/types.go", true}, // results in 25K file batch! 
{"kuma", "api/generic/insights.go", true}, {"oracle", "dataintegration/data_type.go", false}, // diagnoseSave fails because this package is generated {"pkgsite", "internal/frontend/server.go", true}, diff --git a/gopls/internal/test/integration/bench/repo_test.go b/gopls/internal/test/integration/bench/repo_test.go index 65728c00552..e0b323ccd3b 100644 --- a/gopls/internal/test/integration/bench/repo_test.go +++ b/gopls/internal/test/integration/bench/repo_test.go @@ -13,6 +13,7 @@ import ( "log" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -109,8 +110,13 @@ var repos = map[string]*repo{ // getRepo gets the requested repo, and skips the test if -short is set and // repo is not configured as a short repo. +// +// The name may include an optional ".foo" suffix after the repo +// identifier. This allows several tests to use the same repo but have +// distinct test names and associated file names. func getRepo(tb testing.TB, name string) *repo { tb.Helper() + name, _, _ = strings.Cut(name, ".") // remove ".foo" suffix repo := repos[name] if repo == nil { tb.Fatalf("repo %s does not exist", name) From 14c014cf80eb58e3f78c92f38c86e5d921096d91 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 22 May 2025 11:25:09 -0400 Subject: [PATCH 129/196] internal/tokeninternal: optimize AddExistingFiles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A sequence of calls to AddExistingFiles typically causes a large FileSet to grow in small amounts, resulting in a linear number of calls to O(n log n) sort operations. This consumed 12% of CPU on gopls DidChange/kubernetes benchmark on a file with a large upward dependency graph: the type-checking batch is approximately 25,000 files, and the FileSet grows to this size across thousands of calls. This fraction grows with the size of the project and appears to be the cause of field reports of poor performance.
More efficient is to sort only the handful of new files, then merge with the existing ones, which are already sorted. In-place merging with constant space is too fiddly, so we allocate additional 2x space in the FileSet array and use the top half as the second buffer. This change improves the benchmark from CL 675875 by almost 10x. The more principled asymptotic improvement in CL 675736 will have to wait for go1.25. Apple M1 │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ FileSet_AddExistingFiles/sequence-8 387.91m ± 17% 53.28m ± 3% -86.27% (p=0.000 n=9) Change-Id: I04c3b76ff2df50b6413206f057455366ca71caef Reviewed-on: https://go-review.googlesource.com/c/tools/+/675535 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- internal/tokeninternal/tokeninternal_go124.go | 95 +++++++++++++++---- internal/tokeninternal/tokeninternal_test.go | 40 ++++++++ 2 files changed, 118 insertions(+), 17 deletions(-) diff --git a/internal/tokeninternal/tokeninternal_go124.go b/internal/tokeninternal/tokeninternal_go124.go index da34ae608ca..bcf346cac6f 100644 --- a/internal/tokeninternal/tokeninternal_go124.go +++ b/internal/tokeninternal/tokeninternal_go124.go @@ -10,9 +10,10 @@ package tokeninternal import ( + "cmp" "fmt" "go/token" - "sort" + "slices" "sync" "sync/atomic" "unsafe" @@ -22,7 +23,8 @@ import ( // are not already present. It panics if any pair of files in the // resulting FileSet would overlap. // -// TODO(adonovan): replace with FileSet.AddExistingFiles in go1.25. +// TODO(adonovan): replace with FileSet.AddExistingFiles in go1.25, +// which is much more efficient. func AddExistingFiles(fset *token.FileSet, files []*token.File) { // This function cannot be implemented as: @@ -62,18 +64,77 @@ func AddExistingFiles(fset *token.FileSet, files []*token.File) { ptr.mutex.Lock() defer ptr.mutex.Unlock() - // Merge and sort. - newFiles := append(ptr.files, files...) 
- sort.Slice(newFiles, func(i, j int) bool { - return newFiles[i].Base() < newFiles[j].Base() - }) + cmp := func(x, y *token.File) int { + return cmp.Compare(x.Base(), y.Base()) + } + + // A naive implementation would simply concatenate and sort + // the arrays. However, the typical usage pattern is to + // repeatedly add a handful of files to a large FileSet, which + // would cause O(n) sort operations each of O(n log n), where + // n is the total size. + // + // A more efficient approach is to sort the new items, then + // merge the two sorted lists. Although it is possible to do + // this in-place with only constant additional space, it is + // quite fiddly; see "Practical In-Place Merging", Huang & + // Langston, CACM, 1988. + // (https://dl.acm.org/doi/pdf/10.1145/42392.42403) + // + // If we could change the representation of FileSet, we could + // use double-buffering: allocate a second array of size m+n + // into which we merge the m initial items and the n new ones, + // then we switch the two arrays until the next time. + // + // But since we cannot, for now we grow the existing array by + // doubling to at least 2*(m+n), use the upper half as + // temporary space, then copy back into the lower half. + // Any excess capacity will help amortize future calls to + // AddExistingFiles. + // + // The implementation of FileSet for go1.25 is expected to use + // a balanced tree, making FileSet.AddExistingFiles much more + // efficient; see CL 675736. + + m, n := len(ptr.files), len(files) + size := m + n // final size assuming no duplicates + ptr.files = slices.Grow(ptr.files, 2*size) + ptr.files = append(ptr.files, files...) + + // Sort the new files, without mutating the files argument. + // (The existing ptr.files are already sorted.) + slices.SortFunc(ptr.files[m:size], cmp) + + // Merge old (x) and new (y) files into output array. + // For simplicity, we remove dups and check overlaps as a second pass. 
+ var ( + x, y, out = ptr.files[:m], ptr.files[m:size], ptr.files[size:size] + xi, yi = 0, 0 + ) + for xi < m && yi < n { + xf := x[xi] + yf := y[yi] + switch cmp(xf, yf) { + case -1: + out = append(out, xf) + xi++ + case +1: + out = append(out, yf) + yi++ + default: + yi++ // equal; discard y + } + } + out = append(out, x[xi:]...) + out = append(out, y[yi:]...) - // Reject overlapping files. - // Discard adjacent identical files. - out := newFiles[:0] - for i, file := range newFiles { + // Compact out into start of ptr.files array, + // rejecting overlapping files and + // discarding adjacent identical files. + ptr.files = ptr.files[:0] + for i, file := range out { if i > 0 { - prev := newFiles[i-1] + prev := out[i-1] if file == prev { continue } @@ -83,15 +144,15 @@ func AddExistingFiles(fset *token.FileSet, files []*token.File) { file.Name(), file.Base(), file.Base()+file.Size())) } } - out = append(out, file) + ptr.files = append(ptr.files, file) } - newFiles = out - ptr.files = newFiles + // This ensures that we don't keep a File alive after RemoveFile. + clear(ptr.files[size:cap(ptr.files)]) // Advance FileSet.Base(). - if len(newFiles) > 0 { - last := newFiles[len(newFiles)-1] + if len(ptr.files) > 0 { + last := ptr.files[len(ptr.files)-1] newBase := last.Base() + last.Size() + 1 if ptr.base < newBase { ptr.base = newBase diff --git a/internal/tokeninternal/tokeninternal_test.go b/internal/tokeninternal/tokeninternal_test.go index 7fd14fea6a3..bbe2c060963 100644 --- a/internal/tokeninternal/tokeninternal_test.go +++ b/internal/tokeninternal/tokeninternal_test.go @@ -7,6 +7,7 @@ package tokeninternal_test import ( "fmt" "go/token" + "math/rand/v2" "strings" "testing" @@ -53,3 +54,42 @@ func fsetString(fset *token.FileSet) string { buf.WriteRune('}') return buf.String() } + +// This is a copy of the go/token benchmark from CL 675875. +func BenchmarkFileSet_AddExistingFiles(b *testing.B) { + // Create the "universe" of files. 
+ fset := token.NewFileSet() + var files []*token.File + for range 25000 { + files = append(files, fset.AddFile("", -1, 10000)) + } + rand.Shuffle(len(files), func(i, j int) { + files[i], files[j] = files[j], files[i] + }) + + // choose returns n random files. + choose := func(n int) []*token.File { + res := make([]*token.File, n) + for i := range res { + res[i] = files[rand.IntN(n)] + } + return files[:n] + } + + // Measure the cost of creating a FileSet with a large number + // of files added in small handfuls, with some overlap. + // This case is critical to gopls. + b.Run("sequence", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + fset2 := token.NewFileSet() + // 40% of files are already in the FileSet. + tokeninternal.AddExistingFiles(fset2, files[:10000]) + b.StartTimer() + + for range 1000 { + tokeninternal.AddExistingFiles(fset2, choose(10)) // about one package + } + } + }) +} From f8a56cca385e58555bf52618cc1ec4453e772f71 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 27 May 2025 17:09:50 +0000 Subject: [PATCH 130/196] internal/jsonrpc2_v2: update for go1.16+ Change-Id: If70bb4b7ab236ae921ff23d123cb38d476cfa665 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676516 Auto-Submit: Robert Findley Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- internal/jsonrpc2_v2/net.go | 6 +++--- internal/jsonrpc2_v2/serve_go116.go | 13 ------------- internal/jsonrpc2_v2/serve_pre116.go | 15 --------------- 3 files changed, 3 insertions(+), 31 deletions(-) delete mode 100644 internal/jsonrpc2_v2/serve_go116.go delete mode 100644 internal/jsonrpc2_v2/serve_pre116.go diff --git a/internal/jsonrpc2_v2/net.go b/internal/jsonrpc2_v2/net.go index 15d0aea3af0..a8bd6c081fc 100644 --- a/internal/jsonrpc2_v2/net.go +++ b/internal/jsonrpc2_v2/net.go @@ -100,14 +100,14 @@ func (l *netPiper) Accept(context.Context) (io.ReadWriteCloser, error) { // preferring the latter if already closed at the start of Accept. 
select { case <-l.done: - return nil, errClosed + return nil, net.ErrClosed default: } select { case rwc := <-l.dialed: return rwc, nil case <-l.done: - return nil, errClosed + return nil, net.ErrClosed } } @@ -133,6 +133,6 @@ func (l *netPiper) Dial(ctx context.Context) (io.ReadWriteCloser, error) { case <-l.done: client.Close() server.Close() - return nil, errClosed + return nil, net.ErrClosed } } diff --git a/internal/jsonrpc2_v2/serve_go116.go b/internal/jsonrpc2_v2/serve_go116.go deleted file mode 100644 index 19114502d1c..00000000000 --- a/internal/jsonrpc2_v2/serve_go116.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.16 - -package jsonrpc2 - -import ( - "net" -) - -var errClosed = net.ErrClosed diff --git a/internal/jsonrpc2_v2/serve_pre116.go b/internal/jsonrpc2_v2/serve_pre116.go deleted file mode 100644 index 9e8ece2ea7b..00000000000 --- a/internal/jsonrpc2_v2/serve_pre116.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2020 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.16 - -package jsonrpc2 - -import ( - "errors" -) - -// errClosed is an error with the same string as net.ErrClosed, -// which was added in Go 1.16. 
-var errClosed = errors.New("use of closed network connection") From 66fd75991b0a7b51c0573e565e3cbfcf22776e48 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Fri, 23 May 2025 15:21:57 +0000 Subject: [PATCH 131/196] internal/mcp: add pagination for resources Change-Id: I9f257fa09fe82e53087616859d1ad0c03f57dab6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675696 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/server.go | 63 +++++---- internal/mcp/server_example_test.go | 179 ++++++------------------ internal/mcp/server_test.go | 202 ++++++++++++++++++++++++++++ 3 files changed, 278 insertions(+), 166 deletions(-) create mode 100644 internal/mcp/server_test.go diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 8ad99b80457..18d3038181f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -229,8 +229,17 @@ func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallTo func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() + var cursor string + if params != nil { + cursor = params.Cursor + } + resources, nextCursor, err := paginateList(s.resources, cursor, s.opts.PageSize) + if err != nil { + return nil, err + } res := new(ListResourcesResult) - for r := range s.resources.all() { + res.NextCursor = nextCursor + for _, r := range resources { res.Resources = append(res.Resources, r.Resource) } return res, nil @@ -539,35 +548,38 @@ type pageToken struct { LastUID string // The unique ID of the last resource seen. } -// paginateList returns a slice of features from the given featureSet, based on -// the provided cursor and page size. It also returns a new cursor for the next -// page, or an empty string if there are no more pages. 
-func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) { - encodeCursor := func(uid string) (string, error) { - var buf bytes.Buffer - token := pageToken{LastUID: uid} - encoder := gob.NewEncoder(&buf) - if err := encoder.Encode(token); err != nil { - return "", fmt.Errorf("failed to encode page token: %w", err) - } - return base64.URLEncoding.EncodeToString(buf.Bytes()), nil +// encodeCursor encodes a unique identifier (UID) into an opaque pagination cursor +// by serializing a pageToken struct. +func encodeCursor(uid string) (string, error) { + var buf bytes.Buffer + token := pageToken{LastUID: uid} + encoder := gob.NewEncoder(&buf) + if err := encoder.Encode(token); err != nil { + return "", fmt.Errorf("failed to encode page token: %w", err) } + return base64.URLEncoding.EncodeToString(buf.Bytes()), nil +} - decodeCursor := func(cursor string) (*pageToken, error) { - decodedBytes, err := base64.URLEncoding.DecodeString(cursor) - if err != nil { - return nil, fmt.Errorf("failed to decode cursor: %w", err) - } +// decodeCursor decodes an opaque pagination cursor into the original pageToken struct.
+func decodeCursor(cursor string) (*pageToken, error) { + decodedBytes, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return nil, fmt.Errorf("failed to decode cursor: %w", err) + } - var token pageToken - buf := bytes.NewBuffer(decodedBytes) - decoder := gob.NewDecoder(buf) - if err := decoder.Decode(&token); err != nil { - return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor) - } - return &token, nil + var token pageToken + buf := bytes.NewBuffer(decodedBytes) + decoder := gob.NewDecoder(buf) + if err := decoder.Decode(&token); err != nil { + return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor) } + return &token, nil +} +// paginateList returns a slice of features from the given featureSet, based on +// the provided cursor and page size. It also returns a new cursor for the next +// page, or an empty string if there are no more pages. +func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) { var seq iter.Seq[T] if cursor == "" { seq = fs.all() @@ -593,7 +605,6 @@ func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (featur if count < pageSize+1 { return features, "", nil } - // Trim the extra element from the result. 
nextCursor, err = encodeCursor(fs.uniqueID(features[len(features)-1])) if err != nil { return nil, "", err diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index fd1d22f3580..913a41cff20 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "log" - "slices" "testing" "github.com/google/go-cmp/cmp" @@ -57,158 +56,58 @@ func ExampleServer() { // Output: Hi user } -func TestListTool(t *testing.T) { - toolA := mcp.NewTool("apple", "apple tool", SayHi) - toolB := mcp.NewTool("banana", "banana tool", SayHi) - toolC := mcp.NewTool("cherry", "cherry tool", SayHi) - testCases := []struct { - tools []*mcp.ServerTool - want []*mcp.Tool - pageSize int - }{ - { - // Simple test. - []*mcp.ServerTool{toolA, toolB, toolC}, - []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, - mcp.DefaultPageSize, - }, - { - // Tools should be ordered by tool name. - []*mcp.ServerTool{toolC, toolA, toolB}, - []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, - mcp.DefaultPageSize, - }, - { - // Page size of 1 should yield the first tool only. - []*mcp.ServerTool{toolC, toolA, toolB}, - []*mcp.Tool{toolA.Tool}, - 1, - }, - { - // Page size of 2 should yield the first 2 tools only. - []*mcp.ServerTool{toolC, toolA, toolB}, - []*mcp.Tool{toolA.Tool, toolB.Tool}, - 2, - }, - { - // Page size of 3 should yield all tools. - []*mcp.ServerTool{toolC, toolA, toolB}, - []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}, - 3, - }, - { - []*mcp.ServerTool{}, - nil, - 1, - }, - } - ctx := context.Background() - for _, tc := range testCases { - server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: tc.pageSize}) - server.AddTools(tc.tools...) 
- clientTransport, serverTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) - if err != nil { - log.Fatal(err) - } - client := mcp.NewClient("client", "v0.0.1", nil) - clientSession, err := client.Connect(ctx, clientTransport) - if err != nil { - log.Fatal(err) - } - res, err := clientSession.ListTools(ctx, nil) - serverSession.Close() - clientSession.Close() - if err != nil { - log.Fatal(err) - } - if len(res.Tools) != len(tc.want) { - t.Fatalf("expected %d tools, got %d", len(tc.want), len(res.Tools)) - } - if diff := cmp.Diff(res.Tools, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("expected tools %+v, got %+v", tc.want, res.Tools) - } - if tc.pageSize < len(tc.tools) && res.NextCursor == "" { - t.Fatalf("expected next cursor, got none") - } - } -} - -func TestListToolPaginateInvalidCursor(t *testing.T) { - toolA := mcp.NewTool("apple", "apple tool", SayHi) - ctx := context.Background() +// createSessions creates and connects an in-memory client and server session for testing purposes. 
+func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { server := mcp.NewServer("server", "v0.0.1", nil) - server.AddTools(toolA) - clientTransport, serverTransport := mcp.NewInMemoryTransports() + client := mcp.NewClient("client", "v0.0.1", nil) + serverTransport, clientTransport := mcp.NewInMemoryTransports() serverSession, err := server.Connect(ctx, serverTransport) if err != nil { log.Fatal(err) } - client := mcp.NewClient("client", "v0.0.1", nil) clientSession, err := client.Connect(ctx, clientTransport) if err != nil { log.Fatal(err) } - _, err = clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: "invalid"}) - if err == nil { - t.Fatalf("expected error, got none") - } - serverSession.Close() - clientSession.Close() + return clientSession, serverSession, server } -func TestListToolPaginate(t *testing.T) { - serverTools := []*mcp.ServerTool{ - mcp.NewTool("apple", "apple tool", SayHi), - mcp.NewTool("banana", "banana tool", SayHi), - mcp.NewTool("cherry", "cherry tool", SayHi), - mcp.NewTool("durian", "durian tool", SayHi), - mcp.NewTool("elderberry", "elderberry tool", SayHi), +func TestListTool(t *testing.T) { + toolA := mcp.NewTool("apple", "apple tool", SayHi) + toolB := mcp.NewTool("banana", "banana tool", SayHi) + toolC := mcp.NewTool("cherry", "cherry tool", SayHi) + tools := []*mcp.ServerTool{toolA, toolB, toolC} + wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} + ctx := context.Background() + clientSession, serverSession, server := createSessions(ctx) + defer clientSession.Close() + defer serverSession.Close() + server.AddTools(tools...) 
+ res, err := clientSession.ListTools(ctx, nil) + if err != nil { + t.Fatal("ListTools() failed:", err) } - var wantTools []*mcp.Tool - for _, tool := range serverTools { - wantTools = append(wantTools, tool.Tool) + if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff) } - ctx := context.Background() - // Try all possible page sizes, ensuring we get the correct list of tools. - for pageSize := 1; pageSize < len(serverTools)+1; pageSize++ { - server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: pageSize}) - server.AddTools(serverTools...) - clientTransport, serverTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) - if err != nil { - log.Fatal(err) - } - client := mcp.NewClient("client", "v0.0.1", nil) - clientSession, err := client.Connect(ctx, clientTransport) - if err != nil { - log.Fatal(err) - } - var gotTools []*mcp.Tool - var nextCursor string - wantChunks := slices.Collect(slices.Chunk(wantTools, pageSize)) - index := 0 - // Iterate through all pages, comparing sub-slices to the paginated list. - for { - res, err := clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: nextCursor}) - if err != nil { - log.Fatal(err) - } - gotTools = append(gotTools, res.Tools...) 
- nextCursor = res.NextCursor - if diff := cmp.Diff(res.Tools, wantChunks[index], cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("expected %v, got %v, (-want +got):\n%s", wantChunks[index], res.Tools, diff) - } - if res.NextCursor == "" { - break - } - index++ - } - serverSession.Close() - clientSession.Close() +} - if len(gotTools) != len(wantTools) { - t.Fatalf("expected %d tools, got %d", len(wantTools), len(gotTools)) - } +func TestListResources(t *testing.T) { + resourceA := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://apple"}} + resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}} + resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}} + resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} + wantResource := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} + ctx := context.Background() + clientSession, serverSession, server := createSessions(ctx) + defer clientSession.Close() + defer serverSession.Close() + server.AddResources(resources...) + res, err := clientSession.ListResources(ctx, nil) + if err != nil { + t.Fatal("ListResources() failed:", err) + } + if diff := cmp.Diff(wantResource, res.Resources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff) } } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 00000000000..67a13f9b9ba --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,202 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "log" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type TestItem struct { + Name string + Value string +} + +var allTestItems = []*TestItem{ + {"alpha", "val-A"}, + {"bravo", "val-B"}, + {"charlie", "val-C"}, + {"delta", "val-D"}, + {"echo", "val-E"}, + {"foxtrot", "val-F"}, + {"golf", "val-G"}, + {"hotel", "val-H"}, + {"india", "val-I"}, + {"juliet", "val-J"}, + {"kilo", "val-K"}, +} + +// getCursor encodes a string input into a URL-safe base64 cursor, +// fatally logging any encoding errors. +func getCursor(input string) string { + cursor, err := encodeCursor(input) + if err != nil { + log.Fatalf("encodeCursor(%s) error = %v", input, err) + } + return cursor +} + +func TestPaginateBasic(t *testing.T) { + testCases := []struct { + name string + initialItems []*TestItem + inputCursor string + inputPageSize int + wantFeatures []*TestItem + wantNextCursor string + wantErr bool + }{ + { + name: "FirstPage_DefaultSize_Full", + initialItems: allTestItems, + inputCursor: "", + inputPageSize: 5, + wantFeatures: allTestItems[0:5], + wantNextCursor: getCursor("echo"), // Based on last item of first page + wantErr: false, + }, + { + name: "SecondPage_DefaultSize_Full", + initialItems: allTestItems, + inputCursor: getCursor("echo"), + inputPageSize: 5, + wantFeatures: allTestItems[5:10], + wantNextCursor: getCursor("juliet"), // Based on last item of second page + wantErr: false, + }, + { + name: "SecondPage_DefaultSize_Full_OutOfOrder", + initialItems: append(allTestItems[5:], allTestItems[0:5]...), + inputCursor: getCursor("echo"), + inputPageSize: 5, + wantFeatures: allTestItems[5:10], + wantNextCursor: getCursor("juliet"), // Based on last item of second page + wantErr: false, + }, + { + name: "SecondPage_DefaultSize_Full_Duplicates", + initialItems: append(allTestItems, allTestItems[0:5]...), + inputCursor: getCursor("echo"), + inputPageSize: 5, + wantFeatures: allTestItems[5:10], + wantNextCursor: getCursor("juliet"), // 
Based on last item of second page + wantErr: false, + }, + { + name: "LastPage_Remaining", + initialItems: allTestItems, + inputCursor: getCursor("juliet"), + inputPageSize: 5, + wantFeatures: allTestItems[10:11], // Only 1 item left + wantNextCursor: "", // No more pages + wantErr: false, + }, + { + name: "PageSize_1", + initialItems: allTestItems, + inputCursor: "", + inputPageSize: 1, + wantFeatures: allTestItems[0:1], + wantNextCursor: getCursor("alpha"), + wantErr: false, + }, + { + name: "PageSize_All", + initialItems: allTestItems, + inputCursor: "", + inputPageSize: len(allTestItems), // Page size equals total + wantFeatures: allTestItems, + wantNextCursor: "", // No more pages + wantErr: false, + }, + { + name: "PageSize_LargerThanAll", + initialItems: allTestItems, + inputCursor: "", + inputPageSize: len(allTestItems) + 5, // Page size larger than total + wantFeatures: allTestItems, + wantNextCursor: "", + wantErr: false, + }, + { + name: "EmptySet", + initialItems: nil, + inputCursor: "", + inputPageSize: 5, + wantFeatures: nil, + wantNextCursor: "", + wantErr: false, + }, + { + name: "InvalidCursor", + initialItems: allTestItems, + inputCursor: "not-a-valid-gob-base64-cursor", + inputPageSize: 5, + wantFeatures: nil, // Should be nil for error cases + wantNextCursor: "", + wantErr: true, + }, + { + name: "AboveNonExistentID", + initialItems: allTestItems, + inputCursor: getCursor("dne"), // A UID that doesn't exist + inputPageSize: 5, + wantFeatures: allTestItems[4:9], // Should return elements above UID. + wantNextCursor: getCursor("india"), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fs := newFeatureSet(func(t *TestItem) string { return t.Name }) + fs.add(tc.initialItems...) 
+ gotFeatures, gotNextCursor, err := paginateList(fs, tc.inputCursor, tc.inputPageSize) + if (err != nil) != tc.wantErr { + t.Errorf("paginateList(%s) error, got %v, wantErr %v", tc.name, err, tc.wantErr) + } + if diff := cmp.Diff(tc.wantFeatures, gotFeatures); diff != "" { + t.Errorf("paginateList(%s) mismatch (-want +got):\n%s", tc.name, diff) + } + if tc.wantNextCursor != gotNextCursor { + t.Errorf("paginateList(%s) nextCursor, got %v, want %v", tc.name, gotNextCursor, tc.wantNextCursor) + } + }) + } +} + +func TestPaginateVariousPageSizes(t *testing.T) { + fs := newFeatureSet(func(t *TestItem) string { return t.Name }) + fs.add(allTestItems...) + // Try all possible page sizes, ensuring we get the correct list of items. + for pageSize := 1; pageSize < len(allTestItems)+1; pageSize++ { + var gotItems []*TestItem + var nextCursor string + wantChunks := slices.Collect(slices.Chunk(allTestItems, pageSize)) + index := 0 + // Iterate through all pages, comparing sub-slices to the paginated list. + for { + gotFeatures, gotNextCursor, err := paginateList(fs, nextCursor, pageSize) + if err != nil { + } + if diff := cmp.Diff(wantChunks[index], gotFeatures); diff != "" { + t.Errorf("paginateList mismatch (-want +got):\n%s", diff) + } + gotItems = append(gotItems, gotFeatures...) + nextCursor = gotNextCursor + if nextCursor == "" { + break + } + index++ + } + + if len(gotItems) != len(allTestItems) { + t.Fatalf("paginateList() returned %d items, want %d", len(allTestItems), len(gotItems)) + } + } +} From 845000b76ebd98f68f6a7820138451b86d3a8f8f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 21 May 2025 07:11:50 -0400 Subject: [PATCH 132/196] internal/mcp: meta and progress token Meta is a struct that contains a map[string]any, and a ProgressToken field. Each Params and Result type has a Meta as a value field. Meta marshals and unmarshals specially, so the ProgressToken field becomes part of the map. 
The marshalling code is generalized in anticipation of every Params and Result growing a map of additional properties. Change-Id: I6d2f34270d33501c96437f24f91ef3e51d20d955 Reviewed-on: https://go-review.googlesource.com/c/tools/+/674956 Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI Auto-Submit: Jonathan Amsterdam Reviewed-by: Robert Findley --- internal/mcp/client.go | 22 +-- internal/mcp/design/design.md | 17 +- internal/mcp/generate.go | 35 +++- internal/mcp/jsonschema/validate.go | 11 +- internal/mcp/jsonschema/validate_test.go | 21 +++ internal/mcp/mcp_test.go | 2 +- internal/mcp/protocol.go | 201 ++++++++++++++++------- internal/mcp/server.go | 14 +- internal/mcp/shared.go | 85 ++++++++-- internal/mcp/shared_test.go | 70 ++++++++ internal/mcp/util.go | 137 +++++++++++++++ internal/mcp/util_test.go | 48 ++++++ 12 files changed, 548 insertions(+), 115 deletions(-) create mode 100644 internal/mcp/shared_test.go create mode 100644 internal/mcp/util_test.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 79ccd89e859..a42e74678fa 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -162,7 +162,7 @@ func (c *Client) RemoveRoots(uris ...string) { // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. // If there was a change, it notifies a snapshot of the sessions. -func (c *Client) changeAndNotify(notification string, params any, change func() bool) { +func (c *Client) changeAndNotify(notification string, params Params, change func() bool) { var sessions []*ClientSession // Lock for the change, but not for the notification. c.mu.Lock() @@ -231,8 +231,8 @@ func (cs *ClientSession) methodHandler() MethodHandler[ClientSession] { // getConn implements [session.getConn]. 
func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } -func (c *ClientSession) ping(ct context.Context, params *PingParams) (struct{}, error) { - return struct{}{}, nil +func (c *ClientSession) ping(ct context.Context, params *PingParams) (Result, error) { + return emptyResult{}, nil } // Ping makes an MCP "ping" request to the server. @@ -296,29 +296,21 @@ func (c *ClientSession) ReadResource(ctx context.Context, params *ReadResourcePa return standardCall[ReadResourceResult](ctx, c.conn, methodReadResource, params) } -func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (any, error) { +func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) } -func (c *Client) callPromptChangedHandler(ctx context.Context, s *ClientSession, params *PromptListChangedParams) (any, error) { +func (c *Client) callPromptChangedHandler(ctx context.Context, s *ClientSession, params *PromptListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.PromptListChangedHandler, s, params) } -func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSession, params *ResourceListChangedParams) (any, error) { +func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSession, params *ResourceListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) } -func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (any, error) { +func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, cs, params) } return nil, nil } - -func standardCall[TRes, TParams any](ctx context.Context, conn 
*jsonrpc2.Connection, method string, params TParams) (*TRes, error) { - var result TRes - if err := call(ctx, conn, method, params, &result); err != nil { - return nil, err - } - return &result, nil -} diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 28de9aa6b4b..642fdda3add 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -243,7 +243,7 @@ Types needed for the protocol are generated from the [JSON schema of the MCP spe These types will be included in the `mcp` package, but will be unexported unless they are needed for the user-facing API. Notably, JSON-RPC request types are elided, since they are handled by the `jsonrpc2` package and should not be observed by the user. -For user-provided data, we use `json.RawMessage`, so that marshalling/unmarshalling can be delegated to the business logic of the client or server. +For user-provided data, we use `json.RawMessage` or `map[string]any`, depending on the use case. For union types, which can't be represented in Go (specifically `Content` and `ResourceContents`), we prefer distinguished unions: struct types with fields corresponding to the union of all properties for union elements. @@ -255,9 +255,9 @@ type ReadResourceParams struct { } type CallToolResult struct { - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content []Content `json:"content"` - IsError bool `json:"isError,omitempty"` + Meta Meta `json:"_meta,omitempty"` + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` } // Content is the wire format for content. @@ -276,6 +276,7 @@ type Content struct { func NewTextContent(text string) *Content // etc. ``` +The `Meta` type includes a `map[string]any` for arbitrary data, and a `ProgressToken` field. **Differences from mcp-go**: these types are largely similar, but our type generator flattens types rather than using struct embedding. 
@@ -480,15 +481,21 @@ The server observes a client cancellation as a cancelled context. ### Progress handling -A caller can request progress notifications by setting the `ProgressToken` field on any request. +A caller can request progress notifications by setting the `Meta.ProgressToken` field on any request. ```go type XXXParams struct { // where XXX is each type of call + Meta Meta ... +} + +type Meta struct { + Data map[string]any ProgressToken any // string or int } ``` + Handlers can notify their peer about progress by calling the `NotifyProgress` method. The notification is only sent if the peer requested it by providing a progress token. ```go diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index a7cfe6d745e..2d62d4350ca 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -299,13 +299,21 @@ func loadSchema(schemaFile string) (data []byte, err error) { func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, named map[string]*bytes.Buffer) error { var w io.Writer = io.Discard - if typeName := config.Name; typeName != "-" { + var typeName string + if typeName = config.Name; typeName != "-" { if typeName == "" { typeName = configName } if _, ok := named[typeName]; ok { return nil } + // The JSON schema does not accurately represent the source of truth, which is typescript. + // Every Params and Result type should have a _meta property. + // Also, those with a progress token will turn into a struct; we want the progress token to + // be a map item. So replace all metas. + if strings.HasSuffix(typeName, "Params") || strings.HasSuffix(typeName, "Result") { + def.Properties["_meta"] = metaSchema + } buf := new(bytes.Buffer) w = buf named[typeName] = buf @@ -318,6 +326,12 @@ func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, nam return err // Better error here? } fmt.Fprintf(w, "\n") + + // Any decl with a _meta field gets a GetMeta method. 
+ if _, ok := def.Properties["_meta"]; ok { + fmt.Fprintf(w, "\nfunc (x *%s) GetMeta() *Meta { return &x.Meta }", typeName) + } + return nil } @@ -354,11 +368,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma // For types that explicitly allow additional properties, we can either // unmarshal them into a map[string]any, or delay unmarshalling with - // json.RawMessage. For now, use json.RawMessage as it defers the choice. - // - // TODO(jba): further refine this classification of object schemas. - // For example, the typescript "object" type, which should map to a Go "any", - // is represented in schema.json by `{type: object, properties: {}, additionalProperties: true}`. + // json.RawMessage. We use any. if def.Type == "object" && canHaveAdditionalProperties(def) && def.Properties == nil { w.Write([]byte("map[string]")) return writeType(w, nil, def.AdditionalProperties, named) @@ -372,7 +382,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma fmt.Fprintf(w, "*Content") } else { // E.g. union types. - fmt.Fprintf(w, "json.RawMessage") + fmt.Fprintf(w, "any") } } else { switch def.Type { @@ -398,6 +408,11 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma if fieldDef.Description != "" { fmt.Fprintf(w, "%s\n", toComment(fieldDef.Description)) } + if name == "_meta" { + fmt.Fprintln(w, "\tMeta Meta `json:\"_meta,omitempty\"`") + continue + } + export := exportName(name) fmt.Fprintf(w, "\t%s ", export) @@ -551,6 +566,12 @@ func isStruct(s *jsonschema.Schema) bool { return s.Type == "object" && s.Properties != nil && !canHaveAdditionalProperties(s) } +// The schema for "_meta". +// We only need the description: the rest is a special case. +var metaSchema = &jsonschema.Schema{ + Description: "This property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses.", +} + // schemaJSON returns the JSON for s. 
// For debugging. func schemaJSON(s *jsonschema.Schema) string { diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 6651ec10ea2..1bc58dfc116 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -616,14 +616,19 @@ func structPropertiesOf(t reflect.Type) propertyMap { // jsonName returns the name for f as would be used by [json.Marshal]. // That is the name in the json struct tag, or the field name if there is no tag. -// If f is not exported or the tag name is "-", jsonName returns "", false. +// If f is not exported or the tag is "-", jsonName returns "", false. func jsonName(f reflect.StructField) (string, bool) { if !f.IsExported() { return "", false } if tag, ok := f.Tag.Lookup("json"); ok { - if name, _, _ := strings.Cut(tag, ","); name != "" { - return name, name != "-" + name, _, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return "", false + } + if name != "" { + return name, true } } return f.Name, true diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index b6b8d61a191..b5d75438e17 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "testing" ) @@ -171,6 +172,26 @@ func TestStructInstance(t *testing.T) { } } +func TestJSONName(t *testing.T) { + type S struct { + A int + B int `json:","` + C int `json:"-"` + D int `json:"-,"` + E int `json:"echo"` + F int `json:"foxtrot,omitempty"` + g int `json:"golf"` + } + want := []string{"A", "B", "", "-", "echo", "foxtrot", ""} + tt := reflect.TypeFor[S]() + for i := range tt.NumField() { + got, _ := jsonName(tt.Field(i)) + if got != want[i] { + t.Errorf("got %q, want %q", got, want[i]) + } + } +} + // loadRemote loads a remote reference used in the test suite. 
func loadRemote(uri *url.URL) (*Schema, error) { // Anything with localhost:1234 refers to the remotes directory in the test suite repo. diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 798da88e313..8a8ddfd1290 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -620,7 +620,7 @@ func TestMiddleware(t *testing.T) { // with the given prefix. func traceCalls[S ClientSession | ServerSession](w io.Writer, prefix string) Middleware[S] { return func(h MethodHandler[S]) MethodHandler[S] { - return func(ctx context.Context, sess *S, method string, params any) (any, error) { + return func(ctx context.Context, sess *S, method string, params Params) (Result, error) { fmt.Fprintf(w, "%s >%s\n", prefix, method) defer fmt.Fprintf(w, "%s <%s\n", prefix, method) return h(ctx, sess, method, params) diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 25aeff1f567..15babb6fd5f 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -29,10 +29,15 @@ type Annotations struct { } type CallToolParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` Arguments json.RawMessage `json:"arguments,omitempty"` Name string `json:"name"` } +func (x *CallToolParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a tool call. // // Any errors that originate from the tool SHOULD be reported inside the result @@ -44,17 +49,22 @@ type CallToolParams struct { // server does not support tool calls, or any other exceptional conditions, // should be reported as an MCP error response. type CallToolResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. 
- Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content []*Content `json:"content"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` + Content []*Content `json:"content"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). IsError bool `json:"isError,omitempty"` } +func (x *CallToolResult) GetMeta() *Meta { return &x.Meta } + type CancelledParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An optional string describing the reason for the cancellation. This MAY be // logged or presented to the user. Reason string `json:"reason,omitempty"` @@ -65,6 +75,8 @@ type CancelledParams struct { RequestID any `json:"requestId"` } +func (x *CancelledParams) GetMeta() *Meta { return &x.Meta } + // Capabilities a client may support. Known capabilities are defined here, in // this schema, but this is not a closed set: any client can define its own, // additional capabilities. @@ -82,6 +94,9 @@ type ClientCapabilities struct { } type CreateMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // A request to include context from one or more MCP servers (including the // caller), to be attached to the prompt. The client MAY ignore this request. IncludeContext string `json:"includeContext,omitempty"` @@ -103,15 +118,17 @@ type CreateMessageParams struct { Temperature float64 `json:"temperature,omitempty"` } +func (x *CreateMessageParams) GetMeta() *Meta { return &x.Meta } + // The client's response to a sampling/create_message request from the server. 
// The client should inform the user before returning the sampled message, to // allow them to inspect the response (human in the loop) and decide whether to // allow the server to see it. type CreateMessageResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Content *Content `json:"content"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` + Content *Content `json:"content"` // The name of the model that generated the message. Model string `json:"model"` Role Role `json:"role"` @@ -119,24 +136,36 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (x *CreateMessageResult) GetMeta() *Meta { return &x.Meta } + type GetPromptParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // Arguments to use for templating the prompt. Arguments map[string]string `json:"arguments,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` } +func (x *GetPromptParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a prompts/get request from the client. type GetPromptResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An optional description for the prompt. 
Description string `json:"description,omitempty"` Messages []*PromptMessage `json:"messages"` } +func (x *GetPromptResult) GetMeta() *Meta { return &x.Meta } + type InitializeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` Capabilities *ClientCapabilities `json:"capabilities"` ClientInfo *implementation `json:"clientInfo"` // The latest version of the Model Context Protocol that the client supports. @@ -144,13 +173,15 @@ type InitializeParams struct { ProtocolVersion string `json:"protocolVersion"` } +func (x *InitializeParams) GetMeta() *Meta { return &x.Meta } + // After receiving an initialize request from the client, the server sends this // response. type InitializeResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Capabilities *serverCapabilities `json:"capabilities"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` + Capabilities *serverCapabilities `json:"capabilities"` // Instructions describing how to use the server and its features. // // This can be used by clients to improve the LLM's understanding of available @@ -164,83 +195,108 @@ type InitializeResult struct { ServerInfo *implementation `json:"serverInfo"` } +func (x *InitializeResult) GetMeta() *Meta { return &x.Meta } + type InitializedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta Meta `json:"_meta,omitempty"` } +func (x *InitializedParams) GetMeta() *Meta { return &x.Meta } + type ListPromptsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the current pagination position. If provided, // the server should return results starting after this cursor. Cursor string `json:"cursor,omitempty"` } +func (x *ListPromptsParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a prompts/list request from the client. type ListPromptsResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. NextCursor string `json:"nextCursor,omitempty"` Prompts []*Prompt `json:"prompts"` } +func (x *ListPromptsResult) GetMeta() *Meta { return &x.Meta } + type ListResourcesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the current pagination position. If provided, // the server should return results starting after this cursor. Cursor string `json:"cursor,omitempty"` } +func (x *ListResourcesParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a resources/list request from the client. 
type ListResourcesResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. NextCursor string `json:"nextCursor,omitempty"` Resources []*Resource `json:"resources"` } +func (x *ListResourcesResult) GetMeta() *Meta { return &x.Meta } + type ListRootsParams struct { - Meta struct { - // If specified, the caller is requesting out-of-band progress notifications for - // this request (as represented by notifications/progress). The value of this - // parameter is an opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these notifications. - ProgressToken any `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *ListRootsParams) GetMeta() *Meta { return &x.Meta } + // The client's response to a roots/list request from the server. This result // contains an array of Root objects, each representing a root directory or file // that the server can operate on. type ListRootsResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Roots []*Root `json:"roots"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta Meta `json:"_meta,omitempty"` + Roots []*Root `json:"roots"` } +func (x *ListRootsResult) GetMeta() *Meta { return &x.Meta } + type ListToolsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the current pagination position. If provided, // the server should return results starting after this cursor. Cursor string `json:"cursor,omitempty"` } +func (x *ListToolsParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a tools/list request from the client. type ListToolsResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // An opaque token representing the pagination position after the last returned // result. If present, there may be more results available. NextCursor string `json:"nextCursor,omitempty"` Tools []*Tool `json:"tools"` } +func (x *ListToolsResult) GetMeta() *Meta { return &x.Meta } + // The severity of a log message. // // These map to syslog message severities, as specified in RFC-5424: @@ -248,6 +304,9 @@ type ListToolsResult struct { type LoggingLevel string type LoggingMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // The data to be logged, such as a string message or an object. Any JSON // serializable type is allowed here. 
Data any `json:"data"` @@ -257,6 +316,8 @@ type LoggingMessageParams struct { Logger string `json:"logger,omitempty"` } +func (x *LoggingMessageParams) GetMeta() *Meta { return &x.Meta } + // Hints to use for model selection. // // Keys not declared here are currently left unspecified by the spec and are up @@ -310,15 +371,13 @@ type ModelPreferences struct { } type PingParams struct { - Meta struct { - // If specified, the caller is requesting out-of-band progress notifications for - // this request (as represented by notifications/progress). The value of this - // parameter is an opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these notifications. - ProgressToken any `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *PingParams) GetMeta() *Meta { return &x.Meta } + // A prompt or prompt template that the server offers. type Prompt struct { // A list of arguments to use for templating the prompt. @@ -340,11 +399,13 @@ type PromptArgument struct { } type PromptListChangedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *PromptListChangedParams) GetMeta() *Meta { return &x.Meta } + // Describes a message returned as part of a prompt. 
// // This is similar to `SamplingMessage`, but also supports the embedding of @@ -355,19 +416,26 @@ type PromptMessage struct { } type ReadResourceParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // The URI of the resource to read. The URI can use any protocol; it is up to // the server how to interpret it. URI string `json:"uri"` } +func (x *ReadResourceParams) GetMeta() *Meta { return &x.Meta } + // The server's response to a resources/read request from the client. type ReadResourceResult struct { - // This result property is reserved by the protocol to allow clients and servers - // to attach additional metadata to their responses. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` - Contents *ResourceContents `json:"contents"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` + Contents *ResourceContents `json:"contents"` } +func (x *ReadResourceResult) GetMeta() *Meta { return &x.Meta } + // A known resource that the server is capable of reading. type Resource struct { // Optional annotations for the client. @@ -394,11 +462,13 @@ type Resource struct { } type ResourceListChangedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *ResourceListChangedParams) GetMeta() *Meta { return &x.Meta } + // The sender or recipient of messages and data in a conversation. 
type Role string @@ -415,11 +485,13 @@ type Root struct { } type RootsListChangedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *RootsListChangedParams) GetMeta() *Meta { return &x.Meta } + // Present if the client supports sampling from an LLM. type SamplingCapabilities struct { } @@ -431,12 +503,17 @@ type SamplingMessage struct { } type SetLevelParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` // The level of logging that the client wants to receive from the server. The // server should send all logs at this level and higher (i.e., more severe) to // the client as notifications/message. Level LoggingLevel `json:"level"` } +func (x *SetLevelParams) GetMeta() *Meta { return &x.Meta } + // Definition for a tool the client can call. type Tool struct { // Optional additional tool information. @@ -490,11 +567,13 @@ type ToolAnnotations struct { } type ToolListChangedParams struct { - // This parameter name is reserved by MCP to allow clients and servers to attach - // additional metadata to their notifications. - Meta map[string]json.RawMessage `json:"_meta,omitempty"` + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta Meta `json:"_meta,omitempty"` } +func (x *ToolListChangedParams) GetMeta() *Meta { return &x.Meta } + // Describes the name and version of an MCP implementation. 
type implementation struct { Name string `json:"name"` diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 18d3038181f..85c31b20a87 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -157,7 +157,7 @@ func (s *Server) RemoveResources(uris ...string) { // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. // If there was a change, it notifies a snapshot of the sessions. -func (s *Server) changeAndNotify(notification string, params any, change func() bool) { +func (s *Server) changeAndNotify(notification string, params Params, change func() bool) { var sessions []*ServerSession // Lock for the change, but not for the notification. s.mu.Lock() @@ -358,11 +358,11 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro return connect(ctx, t, s) } -func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (any, error) { +func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) { return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params) } -func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (any, error) { +func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (Result, error) { return callNotificationHandler(ctx, s.opts.RootsListChangedHandler, ss, params) } @@ -519,15 +519,15 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam }, nil } -func (ss *ServerSession) ping(context.Context, *PingParams) (struct{}, error) { - return struct{}{}, nil +func (ss *ServerSession) ping(context.Context, *PingParams) (Result, error) { + return emptyResult{}, nil } -func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) 
(struct{}, error) { +func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (Result, error) { ss.mu.Lock() defer ss.mu.Unlock() ss.logLevel = params.Level - return struct{}{}, nil + return emptyResult{}, nil } // Close performs a graceful shutdown of the connection, preventing new diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index bb7957e05d8..ce9c93b34ff 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -14,6 +14,7 @@ import ( "encoding/json" "fmt" "log" + "reflect" "slices" "time" @@ -21,12 +22,10 @@ import ( ) // A MethodHandler handles MCP messages. -// The params argument is an XXXParams struct pointer, such as *GetPromptParams. -// For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or -// nil with a non-nil error. -// For notifications, a MethodHandler must return nil, nil. +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. type MethodHandler[S ClientSession | ServerSession] func( - ctx context.Context, _ *S, method string, params any) (result any, err error) + ctx context.Context, _ *S, method string, params Params) (result Result, err error) // Middleware is a function from MethodHandlers to MethodHandlers. type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHandler[S] @@ -56,7 +55,7 @@ func toSession[S ClientSession | ServerSession](sess *S) session[S] { } // defaultMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. -func defaultMethodHandler[S ClientSession | ServerSession](ctx context.Context, sess *S, method string, params any) (any, error) { +func defaultMethodHandler[S ClientSession | ServerSession](ctx context.Context, sess *S, method string, params Params) (Result, error) { info, ok := toSession(sess).methodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. 
@@ -87,7 +86,7 @@ func handleRequest[S ClientSession | ServerSession](ctx context.Context, req *js // methodInfo is information about invoking a method. type methodInfo[TSession ClientSession | ServerSession] struct { // unmarshal params from the wire into an XXXParams struct - unmarshalParams func(json.RawMessage) (any, error) + unmarshalParams func(json.RawMessage) (Params, error) // run the code for the method handleMethod MethodHandler[TSession] } @@ -99,40 +98,40 @@ type methodInfo[TSession ClientSession | ServerSession] struct { // - R: results // A typedMethodHandler is like a MethodHandler, but with type information. -type typedMethodHandler[S, P, R any] func(context.Context, *S, P) (R, error) +type typedMethodHandler[S any, P Params, R Result] func(context.Context, *S, P) (R, error) // newMethodInfo creates a methodInfo from a typedMethodHandler. -func newMethodInfo[S ClientSession | ServerSession, P, R any](d typedMethodHandler[S, P, R]) methodInfo[S] { +func newMethodInfo[S ClientSession | ServerSession, P Params, R Result](d typedMethodHandler[S, P, R]) methodInfo[S] { return methodInfo[S]{ - unmarshalParams: func(m json.RawMessage) (any, error) { + unmarshalParams: func(m json.RawMessage) (Params, error) { var p P if err := json.Unmarshal(m, &p); err != nil { return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) } return p, nil }, - handleMethod: func(ctx context.Context, ss *S, _ string, params any) (any, error) { + handleMethod: func(ctx context.Context, ss *S, _ string, params Params) (Result, error) { return d(ctx, ss, params.(P)) }, } } // serverMethod is glue for creating a typedMethodHandler from a method on Server. 
-func serverMethod[P, R any](f func(*Server, context.Context, *ServerSession, P) (R, error)) typedMethodHandler[ServerSession, P, R] { +func serverMethod[P Params, R Result](f func(*Server, context.Context, *ServerSession, P) (R, error)) typedMethodHandler[ServerSession, P, R] { return func(ctx context.Context, ss *ServerSession, p P) (R, error) { return f(ss.server, ctx, ss, p) } } // clientMethod is glue for creating a typedMethodHandler from a method on Server. -func clientMethod[P, R any](f func(*Client, context.Context, *ClientSession, P) (R, error)) typedMethodHandler[ClientSession, P, R] { +func clientMethod[P Params, R Result](f func(*Client, context.Context, *ClientSession, P) (R, error)) typedMethodHandler[ClientSession, P, R] { return func(ctx context.Context, cs *ClientSession, p P) (R, error) { return f(cs.client, ctx, cs, p) } } // sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. -func sessionMethod[S ClientSession | ServerSession, P, R any](f func(*S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { +func sessionMethod[S ClientSession | ServerSession, P Params, R Result](f func(*S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { return func(ctx context.Context, sess *S, p P) (R, error) { return f(sess, ctx, p) } @@ -151,7 +150,7 @@ const ( CodeUnsupportedMethod = -31001 ) -func callNotificationHandler[S ClientSession | ServerSession, P any](ctx context.Context, h func(context.Context, *S, *P), sess *S, params *P) (any, error) { +func callNotificationHandler[S ClientSession | ServerSession, P any](ctx context.Context, h func(context.Context, *S, *P), sess *S, params *P) (Result, error) { if h != nil { h(ctx, sess, params) } @@ -160,7 +159,7 @@ func callNotificationHandler[S ClientSession | ServerSession, P any](ctx context // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. 
-func notifySessions[S ClientSession | ServerSession](sessions []*S, method string, params any) { +func notifySessions[S ClientSession | ServerSession](sessions []*S, method string, params Params) { if sessions == nil { return } @@ -174,3 +173,57 @@ func notifySessions[S ClientSession | ServerSession](sessions []*S, method strin } } } + +func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) { + var result TRes + if err := call(ctx, conn, method, params, &result); err != nil { + return nil, err + } + return &result, nil +} + +type Meta struct { + Data map[string]any `json:",omitempty"` + // For params, the progress token can be nil, a string or an integer. + // It should be nil for results. + ProgressToken any `json:"progressToken,omitempty"` +} + +type metaSansMethods Meta // avoid infinite recursion during marshaling + +func (m Meta) MarshalJSON() ([]byte, error) { + if p := m.ProgressToken; p != nil { + if k := reflect.ValueOf(p).Kind(); k != reflect.Int && k != reflect.String { + return nil, fmt.Errorf("bad type %T for Meta.ProgressToken: must be int or string", p) + } + } + // If ProgressToken is nil, accept Data["progressToken"]. We can't call marshalStructWithMap + // in that case because it will complain about duplicate fields. (We'd have to + // make it much smarter to avoid that; not worth it.) + if m.ProgressToken == nil { + return json.Marshal(m.Data) + } + return marshalStructWithMap((*metaSansMethods)(&m), "Data") +} + +func (m *Meta) UnmarshalJSON(data []byte) error { + return unmarshalStructWithMap(data, (*metaSansMethods)(m), "Data") +} + +// Params is a parameter (input) type for an MCP call or notification. +type Params interface { + // Returns a pointer to the params's Meta field. + GetMeta() *Meta +} + +// Result is a result of an MCP call. +type Result interface { + // Returns a pointer to the result's Meta field. 
+ GetMeta() *Meta +} + +// emptyResult is returned by methods that have no result, like ping. +// Those methods cannot return nil, because jsonrpc2 cannot handle nils. +type emptyResult struct{} + +func (emptyResult) GetMeta() *Meta { panic("should never be called") } diff --git a/internal/mcp/shared_test.go b/internal/mcp/shared_test.go new file mode 100644 index 00000000000..7d3d5334253 --- /dev/null +++ b/internal/mcp/shared_test.go @@ -0,0 +1,70 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestMetaMarshal(t *testing.T) { + // Verify that Meta values round-trip. + for _, meta := range []Meta{ + {Data: nil, ProgressToken: nil}, + {Data: nil, ProgressToken: "p"}, + {Data: map[string]any{"d": true}, ProgressToken: nil}, + {Data: map[string]any{"d": true}, ProgressToken: "p"}, + } { + got := roundTrip(t, meta) + if !cmp.Equal(got, meta) { + t.Errorf("\ngot %#v\nwant %#v", got, meta) + } + } + + // Check errors. + for _, tt := range []struct { + meta Meta + want string + }{ + { + Meta{Data: map[string]any{"progressToken": "p"}, ProgressToken: 1}, + "duplicate", + }, + { + Meta{ProgressToken: true}, + "bad type", + }, + } { + _, err := json.Marshal(tt.meta) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Errorf("%+v: got %v, want error containing %q", tt.meta, err, tt.want) + } + } + + // Accept progressToken in map if the field is nil. + // It will unmarshal by populating ProgressToken. 
+ meta := Meta{Data: map[string]any{"progressToken": "p"}} + got := roundTrip(t, meta) + want := Meta{ProgressToken: "p"} + if !cmp.Equal(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func roundTrip[T any](t *testing.T, v T) T { + t.Helper() + bytes, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + var res T + if err := json.Unmarshal(bytes, &res); err != nil { + t.Fatal(err) + } + return res +} diff --git a/internal/mcp/util.go b/internal/mcp/util.go index 15b3e63d874..3b1f53b124e 100644 --- a/internal/mcp/util.go +++ b/internal/mcp/util.go @@ -6,6 +6,11 @@ package mcp import ( "crypto/rand" + "encoding/json" + "fmt" + "reflect" + "strings" + "sync" ) func assert(cond bool, msg string) { @@ -27,3 +32,135 @@ func randText() string { } return string(src) } + +// marshalStructWithMap marshals its first argument to JSON, treating the field named +// mapField as an embedded map. The first argument must be a pointer to +// a struct. The underlying type of mapField must be a map[string]any, and it must have +// an "omitempty" json tag. +// +// For example, given this struct: +// +// type S struct { +// A int +// Extra map[string]any `json:",omitempty"` +// } +// +// and this value: +// +// s := S{A: 1, Extra: map[string]any{"B": 2}} +// +// the call marshalStructWithMap(&s, "Extra") would return +// +// {"A": 1, "B": 2} +// +// It is an error if the map contains the same key as another struct field's +// JSON name. +// +// marshalStructWithMap calls json.Marshal on a value of type T, so T must not +// have a MarshalJSON method that calls this function, on pain of infinite regress. +// +// TODO: avoid this restriction on T by forcing it to marshal in a default way. +// See https://go.dev/play/p/EgXKJHxEx_R. +func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { + // Marshal the struct and the map separately, and concatenate the bytes.
+ // This strategy is dramatically less complicated than + // constructing a synthetic struct or map with the combined keys. + if s == nil { + return []byte("null"), nil + } + s2 := *s + vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) + mapVal := vMapField.Interface().(map[string]any) + + // Check for duplicates. + names := jsonNames(reflect.TypeFor[T]()) + for key := range mapVal { + if names[key] { + return nil, fmt.Errorf("map key %q duplicates struct field", key) + } + } + + // Clear the map field, relying on the omitempty tag to omit it. + vMapField.Set(reflect.Zero(vMapField.Type())) + structBytes, err := json.Marshal(s2) + if err != nil { + return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) + } + if len(mapVal) == 0 { + return structBytes, nil + } + mapBytes, err := json.Marshal(mapVal) + if err != nil { + return nil, err + } + if len(structBytes) == 2 { // must be "{}" + return mapBytes, nil + } + // "{X}" + "{Y}" => "{X,Y}" + res := append(structBytes[:len(structBytes)-1], ',') + res = append(res, mapBytes[1:]...) + return res, nil +} + +// unmarshalStructWithMap is the inverse of marshalStructWithMap. +// T has the same restrictions as in that function. +func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { + // Unmarshal into the struct, ignoring unknown fields. + if err := json.Unmarshal(data, v); err != nil { + return err + } + // Unmarshal into the map. + m := map[string]any{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + // Delete from the map the fields of the struct. + for n := range jsonNames(reflect.TypeFor[T]()) { + delete(m, n) + } + if len(m) != 0 { + reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) + } + return nil +} + +var jsonNamesMap sync.Map // from reflect.Type to map[string]bool + +// jsonNames returns the set of JSON object keys that t will marshal into. +// t must be a struct type. 
+func jsonNames(t reflect.Type) map[string]bool { + // Lock not necessary: at worst we'll duplicate work. + if val, ok := jsonNamesMap.Load(t); ok { + return val.(map[string]bool) + } + m := map[string]bool{} + for i := range t.NumField() { + if n, ok := jsonName(t.Field(i)); ok { + m[n] = true + } + } + jsonNamesMap.Store(t, m) + return m +} + +// jsonName returns the name for f as would be used by [json.Marshal]. +// That is the name in the json struct tag, or the field name if there is no tag. +// If f is not exported or the tag is "-", jsonName returns "", false. +// +// Copied from jsonschema/validate.go. +func jsonName(f reflect.StructField) (string, bool) { + if !f.IsExported() { + return "", false + } + if tag, ok := f.Tag.Lookup("json"); ok { + name, _, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return "", false + } + if name != "" { + return name, true + } + } + return f.Name, true +} diff --git a/internal/mcp/util_test.go b/internal/mcp/util_test.go new file mode 100644 index 00000000000..e7b727cc396 --- /dev/null +++ b/internal/mcp/util_test.go @@ -0,0 +1,48 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestMarshalStructWithMap(t *testing.T) { + type S struct { + A int + B string `json:"b,omitempty"` + u bool + M map[string]any `json:",omitempty"` + } + t.Run("basic", func(t *testing.T) { + s := S{A: 1, B: "two", M: map[string]any{"!@#": true}} + got, err := marshalStructWithMap(&s, "M") + if err != nil { + t.Fatal(err) + } + want := `{"A":1,"b":"two","!@#":true}` + if g := string(got); g != want { + t.Errorf("\ngot %s\nwant %s", g, want) + } + + var un S + if err := unmarshalStructWithMap(got, &un, "M"); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } + }) + t.Run("duplicate", func(t *testing.T) { + s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}} + _, err := marshalStructWithMap(&s, "M") + if err == nil || !strings.Contains(err.Error(), "duplicate") { + t.Errorf("got %v, want error with 'duplicate'", err) + } + }) +} From 82fa2c075907fcbbe59ae75d5533a8d1873e0729 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 26 May 2025 07:49:23 -0400 Subject: [PATCH 133/196] internal/mcp/jsonschema: clearer validation errors Validation errors now more clearly indicate the path from the root schema. - Enhance schema iteration to yield the path, and generalize it by using reflection. - Store each schema's path inside the schema. - Remove the path arg from validate: we don't need it anymore. - Improve the test for the iterators and add one for validation errors.
Change-Id: Ief55133ef4ae2eac0736208a3b9f78ed69933b31 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676375 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI Auto-Submit: Jonathan Amsterdam --- internal/mcp/jsonschema/json_pointer.go | 30 +++--- internal/mcp/jsonschema/resolve.go | 14 ++- internal/mcp/jsonschema/schema.go | 118 +++++++++++++++-------- internal/mcp/jsonschema/schema_test.go | 52 ++++++++-- internal/mcp/jsonschema/validate.go | 62 ++++-------- internal/mcp/jsonschema/validate_test.go | 15 +++ 6 files changed, 185 insertions(+), 106 deletions(-) diff --git a/internal/mcp/jsonschema/json_pointer.go b/internal/mcp/jsonschema/json_pointer.go index 687743ffbae..bcdc6b76ccd 100644 --- a/internal/mcp/jsonschema/json_pointer.go +++ b/internal/mcp/jsonschema/json_pointer.go @@ -28,7 +28,18 @@ import ( "strings" ) -var jsonPointerReplacer = strings.NewReplacer("~0", "~", "~1", "/") +var ( + jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") + jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") +) + +func escapeJSONPointerSegment(s string) string { + return jsonPointerEscaper.Replace(s) +} + +func unescapeJSONPointerSegment(s string) string { + return jsonPointerUnescaper.Replace(s) +} // parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't // convert strings to numbers, because that depends on the traversal: a segment @@ -47,7 +58,7 @@ func parseJSONPointer(ptr string) (segments []string, err error) { if strings.Contains(ptr, "~") { // Undo the simple escaping rules that allow one to include a slash in a segment. 
for i := range segments { - segments[i] = jsonPointerReplacer.Replace(segments[i]) + segments[i] = unescapeJSONPointerSegment(segments[i]) } } return segments, nil @@ -121,17 +132,6 @@ func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) } -// map from JSON names for fields in a Schema to their indexes in the struct. -var schemaFields = map[string][]int{} - -func init() { - for _, f := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { - if name, ok := jsonName(f); ok { - schemaFields[name] = f.Index - } - } -} - // lookupSchemaField returns the value of the field with the given name in v, // or the zero value if there is no such field or it is not of type Schema or *Schema. func lookupSchemaField(v reflect.Value, name string) reflect.Value { @@ -143,8 +143,8 @@ func lookupSchemaField(v reflect.Value, name string) reflect.Value { } return v.FieldByName("Types") } - if index := schemaFields[name]; index != nil { - return v.FieldByIndex(index) + if sf, ok := schemaFieldMap[name]; ok { + return v.FieldByIndex(sf.Index) } return reflect.Value{} } diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index fa3e3d2ad50..63a6df3a69b 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -118,7 +118,7 @@ func (root *Schema) check() error { report := func(err error) { errs = append(errs, err) } seen := map[*Schema]bool{} - for ss := range root.all() { + for ss, path := range root.all() { if seen[ss] { // The schema graph rooted at s is not a tree, but it needs to // be because we assume a unique parent when we store a schema's base @@ -127,6 +127,11 @@ func (root *Schema) check() error { return fmt.Errorf("schemas rooted at %s do not form a tree (saw %s twice)", root, ss) } seen[ss] = true + if len(path) == 0 { + ss.path = "root" + } else { + ss.path = "/" + strings.Join(path, "/") + } 
ss.checkLocal(report) } return errors.Join(errs...) @@ -138,7 +143,8 @@ func (root *Schema) check() error { // It appends the errors it finds to errs. func (s *Schema) checkLocal(report func(error)) { addf := func(format string, args ...any) { - report(fmt.Errorf("jsonschema.Schema: "+format, args...)) + msg := fmt.Sprintf(format, args...) + report(fmt.Errorf("jsonschema.Schema: %s: %s", s.path, msg)) } if s == nil { @@ -165,7 +171,7 @@ func (s *Schema) checkLocal(report func(error)) { if s.Pattern != "" { re, err := regexp.Compile(s.Pattern) if err != nil { - addf("pattern: %w", err) + addf("pattern: %v", err) } else { s.pattern = re } @@ -175,7 +181,7 @@ func (s *Schema) checkLocal(report func(error)) { for reString, subschema := range s.PatternProperties { re, err := regexp.Compile(reString) if err != nil { - addf("patternProperties[%q]: %w", reString, err) + addf("patternProperties[%q]: %v", reString, err) continue } s.patternProperties[re] = subschema diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 0cf9d4d4b7a..245a4ac685f 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -11,9 +11,13 @@ import ( "errors" "fmt" "iter" + "maps" "math" "net/url" + "reflect" "regexp" + "slices" + "strconv" ) // A Schema is a JSON schema object. @@ -139,6 +143,10 @@ type Schema struct { // s.base == s <=> s.uri != nil uri *url.URL + // The JSON Pointer path from the root schema to here. + // Used in errors. + path string + // The schema to which Ref refers. resolvedRef *Schema @@ -180,7 +188,9 @@ func (s *Schema) String() string { if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { return fmt.Sprintf("%q, anchor %s", s.base.uri.String(), a) } - // TODO: return something better, like a JSON Pointer from the base. + if s.path != "" { + return s.path + } return "" } @@ -192,15 +202,6 @@ func (s *Schema) ResolvedRef() *Schema { return s.resolvedRef } -// json returns the schema in json format. 
-func (s *Schema) json() string { - data, err := json.MarshalIndent(s, "", " ") - if err != nil { - return fmt.Sprintf("", err) - } - return string(data) -} - func (s *Schema) basicChecks() error { if s.Type != "" && s.Types != nil { return errors.New("both Type and Types are set; at most one should be") @@ -358,43 +359,57 @@ func (ip *integer) UnmarshalJSON(data []byte) error { func Ptr[T any](x T) *T { return &x } // every applies f preorder to every schema under s including s. +// The second argument to f is the path to the schema appended to the argument path. // It stops when f returns false. -func (s *Schema) every(f func(*Schema) bool) bool { +func (s *Schema) every(f func(*Schema, []string) bool, path []string) bool { return s == nil || - f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) + f(s, path) && s.everyChild(func(s *Schema, p []string) bool { return s.every(f, p) }, path) } // everyChild reports whether f is true for every immediate child schema of s. +// The second argument to f is the path to the schema appended to the argument path. // // It does not call f on nil-valued fields holding individual schemas, like Contains, // because a nil value indicates that the field is absent. // It does call f on nils when they occur in slices and maps, so those invalid values // can be detected when the schema is validated. -func (s *Schema) everyChild(f func(*Schema) bool) bool { - // Fields that contain individual schemas. A nil is valid: it just means the field isn't present. - for _, c := range []*Schema{ - s.Items, s.AdditionalItems, s.Contains, s.PropertyNames, s.AdditionalProperties, - s.If, s.Then, s.Else, s.Not, s.UnevaluatedItems, s.UnevaluatedProperties, - } { - if c != nil && !f(c) { - return false - } +func (s *Schema) everyChild(f func(*Schema, []string) bool, path []string) bool { + if s == nil { + return false } - // Fields that contain slices of schemas. Yield nils so we can check for their presence. 
- for _, sl := range [][]*Schema{s.PrefixItems, s.AllOf, s.AnyOf, s.OneOf} { - for _, c := range sl { - if !f(c) { + var ( + schemaType = reflect.TypeFor[*Schema]() + schemaSliceType = reflect.TypeFor[[]*Schema]() + schemaMapType = reflect.TypeFor[map[string]*Schema]() + ) + v := reflect.ValueOf(s) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. A nil is valid: it just means the field isn't present. + c := fv.Interface().(*Schema) + if c != nil && !f(c, append(path, info.jsonName)) { return false } - } - } - // Fields that are maps of schemas. Ditto about nils. - for _, m := range []map[string]*Schema{ - s.Defs, s.Definitions, s.Properties, s.PatternProperties, s.DependentSchemas, - } { - for _, c := range m { - if !f(c) { - return false + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + // A field that contains a slice of schemas. Yield nils so we can check for their presence. + for i, c := range slice { + if !f(c, append(path, info.jsonName, strconv.Itoa(i))) { + return false + } + } + + case schemaMapType: + // A field that is a map of schemas. Ditto about nils. + // Sort keys for determinism. + m := fv.Interface().(map[string]*Schema) + for _, k := range slices.Sorted(maps.Keys(m)) { + if !f(m[k], append(path, info.jsonName, escapeJSONPointerSegment(k))) { + return false + } } } } @@ -402,11 +417,38 @@ func (s *Schema) everyChild(f func(*Schema) bool) bool { } // all wraps every in an iterator. -func (s *Schema) all() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.every(yield) } +func (s *Schema) all() iter.Seq2[*Schema, []string] { + return func(yield func(*Schema, []string) bool) { s.every(yield, nil) } } // children wraps everyChild in an iterator. 
-func (s *Schema) children() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.everyChild(yield) } +func (s *Schema) children() iter.Seq2[*Schema, []string] { + var pathBuffer [4]string + return func(yield func(*Schema, []string) bool) { s.everyChild(yield, pathBuffer[:0]) } +} + +type structFieldInfo struct { + sf reflect.StructField + jsonName string +} + +var ( + // the visible fields of Schema that have a JSON name, sorted by that name + schemaFieldInfos []structFieldInfo + // map from JSON name to field + schemaFieldMap = map[string]reflect.StructField{} +) + +func init() { + for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { + if name, ok := jsonName(sf); ok { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, name}) + } + } + slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { + return cmp.Compare(i1.jsonName, i2.jsonName) + }) + for _, info := range schemaFieldInfos { + schemaFieldMap[info.jsonName] = info.sf + } } diff --git a/internal/mcp/jsonschema/schema_test.go b/internal/mcp/jsonschema/schema_test.go index 8394bb587fa..921656cf30f 100644 --- a/internal/mcp/jsonschema/schema_test.go +++ b/internal/mcp/jsonschema/schema_test.go @@ -9,6 +9,8 @@ import ( "fmt" "math" "regexp" + "slices" + "strings" "testing" ) @@ -112,15 +114,31 @@ func TestUnmarshalErrors(t *testing.T) { func TestEvery(t *testing.T) { // Schema.every should visit all descendants of a schema, not just the immediate ones. 
s := &Schema{ - Items: &Schema{ - Items: &Schema{}, - }, + Type: "string", + PrefixItems: []*Schema{{Type: "int"}, {Items: &Schema{Type: "null"}}}, + Contains: &Schema{Properties: map[string]*Schema{ + "~1": {Type: "boolean"}, + "p": {}, + }}, + } + + type item struct { + s *Schema + p string } - want := 3 - got := 0 - s.every(func(*Schema) bool { got++; return true }) - if got != want { - t.Errorf("got %d, want %d", got, want) + want := []item{ + {s, ""}, + {s.Contains, "contains"}, + {s.Contains.Properties["p"], "contains/properties/p"}, + {s.Contains.Properties["~1"], "contains/properties/~01"}, + {s.PrefixItems[0], "prefixItems/0"}, + {s.PrefixItems[1], "prefixItems/1"}, + {s.PrefixItems[1].Items, "prefixItems/1/items"}, + } + var got []item + s.every(func(s *Schema, p []string) bool { got = append(got, item{s, strings.Join(p, "/")}); return true }, nil) + if !slices.Equal(got, want) { + t.Errorf("\n got %v\nwant %v", got, want) } } @@ -130,3 +148,21 @@ func mustUnmarshal(t *testing.T, data []byte, ptr any) { t.Fatal(err) } } + +// json returns the schema in json format. +func (s *Schema) json() string { + data, err := json.Marshal(s) + if err != nil { + return fmt.Sprintf("", err) + } + return string(data) +} + +// jsonIndent returns the schema in json format, indented.
+func (s *Schema) jsonIndent() string { + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Sprintf("", err) + } + return string(data) +} diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 1bc58dfc116..5e82bc58deb 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -27,8 +27,7 @@ func (rs *Resolved) Validate(instance any) error { return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) } st := &state{rs: rs} - var pathBuffer [4]any - return st.validate(reflect.ValueOf(instance), st.rs.root, nil, pathBuffer[:0]) + return st.validate(reflect.ValueOf(instance), st.rs.root, nil) } // state is the state of single call to ResolvedSchema.Validate. @@ -42,13 +41,10 @@ type state struct { } // validate validates the reflected value of the instance. -// It keeps track of the path within the instance for better error messages. -func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations, path []any) (err error) { +func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { defer func() { if err != nil { - if p := formatPath(path); p != "" { - err = fmt.Errorf("%s: %w", p, err) - } + err = fmt.Errorf("%s: %w", schema, err) } }() @@ -72,7 +68,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if schema.Type != "" || schema.Types != nil { gotType, ok := jsonType(instance) if !ok { - return fmt.Errorf("%v of type %[1]T is not a valid JSON value", instance) + return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) } if schema.Type != "" { // "number" subsumes integers @@ -163,7 +159,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 if schema.Ref != "" { - if err := st.validate(instance, schema.resolvedRef, &anns, 
path); err != nil { + if err := st.validate(instance, schema.resolvedRef, &anns); err != nil { return err } } @@ -175,7 +171,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an "DynamicRef not resolved properly") if schema.resolvedDynamicRef != nil { // Same as $ref. - if err := st.validate(instance, schema.resolvedDynamicRef, &anns, path); err != nil { + if err := st.validate(instance, schema.resolvedDynamicRef, &anns); err != nil { return err } } else { @@ -199,7 +195,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if dynamicSchema == nil { return fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) } - if err := st.validate(instance, dynamicSchema, &anns, path); err != nil { + if err := st.validate(instance, dynamicSchema, &anns); err != nil { return err } } @@ -214,11 +210,11 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // If any of these fail, then validation fails, even if there is an unevaluatedXXX // keyword in the schema. The spec is unclear about this, but that is the intention. 
- valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns, path) == nil } + valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } if schema.AllOf != nil { for _, ss := range schema.AllOf { - if err := st.validate(instance, ss, &anns, path); err != nil { + if err := st.validate(instance, ss, &anns); err != nil { return err } } @@ -264,7 +260,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an ss = schema.Else } if ss != nil { - if err := st.validate(instance, ss, &anns, path); err != nil { + if err := st.validate(instance, ss, &anns); err != nil { return err } } @@ -281,7 +277,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if i >= instance.Len() { break // shorter is OK } - if err := st.validate(instance.Index(i), ischema, nil, append(path, i)); err != nil { + if err := st.validate(instance.Index(i), ischema, nil); err != nil { return err } } @@ -289,7 +285,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if schema.Items != nil { for i := len(schema.PrefixItems); i < instance.Len(); i++ { - if err := st.validate(instance.Index(i), schema.Items, nil, append(path, i)); err != nil { + if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { return err } } @@ -300,14 +296,13 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an nContains := 0 if schema.Contains != nil { for i := range instance.Len() { - if err := st.validate(instance.Index(i), schema.Contains, nil, append(path, i)); err == nil { + if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { nContains++ anns.noteIndex(i) } } if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { - return fmt.Errorf("contains: %s does not have an item matching %s", - instance, schema.Contains) + return fmt.Errorf("contains: %s does not have an item 
matching %s", instance, schema.Contains) } } @@ -366,7 +361,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // That includes validations by subschemas on the same instance, like allOf. for i := anns.endIndex; i < instance.Len(); i++ { if !anns.evaluatedIndexes[i] { - if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil, append(path, i)); err != nil { + if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { return err } } @@ -399,7 +394,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if instance.Kind() == reflect.Struct && val.IsZero() && !schema.isRequired[prop] { continue } - if err := st.validate(val, subschema, nil, append(path, prop)); err != nil { + if err := st.validate(val, subschema, nil); err != nil { return err } evalProps[prop] = true @@ -409,7 +404,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // Check every matching pattern. for re, schema := range schema.patternProperties { if re.MatchString(prop) { - if err := st.validate(val, schema, nil, append(path, prop)); err != nil { + if err := st.validate(val, schema, nil); err != nil { return err } evalProps[prop] = true @@ -421,7 +416,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // Apply to all properties not handled above. for prop, val := range properties(instance) { if !evalProps[prop] { - if err := st.validate(val, schema.AdditionalProperties, nil, append(path, prop)); err != nil { + if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { return err } evalProps[prop] = true @@ -433,7 +428,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // Note: properties unnecessarily fetches each value. We could define a propertyNames function // if performance ever matters. 
for prop := range properties(instance) { - if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil, append(path, prop)); err != nil { + if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { return err } } @@ -493,7 +488,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an for dprop, ss := range schema.DependentSchemas { if hasProperty(dprop) { // TODO: include dependentSchemas[dprop] in the errors. - err := st.validate(instance, ss, &anns, path) + err := st.validate(instance, ss, &anns) if err != nil { return err } @@ -505,7 +500,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // in addition to sibling keywords. for prop, val := range properties(instance) { if !anns.evaluatedProperties[prop] { - if err := st.validate(val, schema.UnevaluatedProperties, nil, append(path, prop)); err != nil { + if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { return err } } @@ -633,18 +628,3 @@ func jsonName(f reflect.StructField) (string, bool) { } return f.Name, true } - -func formatPath(path []any) string { - var b strings.Builder - for i, p := range path { - if n, ok := p.(int); ok { - fmt.Fprintf(&b, "[%d]", n) - } else { - if i > 0 { - b.WriteByte('.') - } - fmt.Fprintf(&b, "%q", p) - } - } - return b.String() -} diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index b5d75438e17..89cc2475b22 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -78,6 +78,21 @@ func TestValidate(t *testing.T) { } } +func TestValidateErrors(t *testing.T) { + schema := &Schema{ + PrefixItems: []*Schema{{Contains: &Schema{Type: "integer"}}}, + } + rs, err := schema.Resolve("", nil) + if err != nil { + t.Fatal(err) + } + err = rs.Validate([]any{[]any{"1"}}) + want := "prefixItems/0" + if err == nil || !strings.Contains(err.Error(), want) { + 
t.Errorf("error:\n%s\ndoes not contain %q", err, want) + } +} + func TestStructInstance(t *testing.T) { instance := struct { I int From cc4b6feda09a4e5b155978ac5b125f65bacc6930 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 27 May 2025 07:32:06 -0400 Subject: [PATCH 134/196] internal/mcp/jsonschema: generalize error wrapping Add a wrapf function to simplify defers that wrap errors. Change-Id: If4bedec7d3f56f609d3ca4620aebe7637a59f6bf Reviewed-on: https://go-review.googlesource.com/c/tools/+/676475 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- internal/mcp/jsonschema/json_pointer.go | 6 +----- internal/mcp/jsonschema/util.go | 7 +++++++ internal/mcp/jsonschema/validate.go | 6 +----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/internal/mcp/jsonschema/json_pointer.go b/internal/mcp/jsonschema/json_pointer.go index bcdc6b76ccd..0769a4c8d9a 100644 --- a/internal/mcp/jsonschema/json_pointer.go +++ b/internal/mcp/jsonschema/json_pointer.go @@ -69,11 +69,7 @@ func parseJSONPointer(ptr string) (segments []string, err error) { // This implementation suffices for JSON Schema: pointers are applied only to Schemas, // and refer only to Schemas. func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("JSON Pointer %q: %w", sptr, err) - } - }() + defer wrapf(&err, "JSON Pointer %q", sptr) segments, err := parseJSONPointer(sptr) if err != nil { diff --git a/internal/mcp/jsonschema/util.go b/internal/mcp/jsonschema/util.go index 7e07345f8cc..58c11ff1df7 100644 --- a/internal/mcp/jsonschema/util.go +++ b/internal/mcp/jsonschema/util.go @@ -282,3 +282,10 @@ func assert(cond bool, msg string) { panic("assertion failed: " + msg) } } + +// wrapf wraps *errp with the given formatted message if *errp is not nil. 
+func wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 5e82bc58deb..510dc73a80e 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -42,11 +42,7 @@ type state struct { // validate validates the reflected value of the instance. func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("%s: %w", schema, err) - } - }() + defer wrapf(&err, "validating %s", schema) st.stack = append(st.stack, schema) // push defer func() { From 15e680e6b2582a70f84e48072f7027d47bf5cb6e Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Thu, 15 May 2025 08:47:17 -0400 Subject: [PATCH 135/196] gopls/.../completion: unimported completion snippets Add code that produces snippets (if the user wants them) for functions from the standard library. 
Change-Id: I8e50ff8dc2ff3f9423adde3acf1d2b986cb9cc0c Reviewed-on: https://go-review.googlesource.com/c/tools/+/673175 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- .../internal/golang/completion/unimported.go | 46 +++++++++++++++++-- .../integration/completion/completion_test.go | 11 +++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/gopls/internal/golang/completion/unimported.go b/gopls/internal/golang/completion/unimported.go index 87c059697f3..9b0d20b08b0 100644 --- a/gopls/internal/golang/completion/unimported.go +++ b/gopls/internal/golang/completion/unimported.go @@ -203,8 +203,10 @@ func (c *completer) stdlibMatches(pkgs []metadata.PackagePath, pkg metadata.Pack } var kind protocol.CompletionItemKind var detail string + var params []string switch sym.Kind { case stdlib.Func: + params = parseSignature(sym.Signature) kind = protocol.FunctionCompletion detail = fmt.Sprintf("func (from %q)", candpkg) case stdlib.Const: @@ -220,12 +222,10 @@ func (c *completer) stdlibMatches(pkgs []metadata.PackagePath, pkg metadata.Pack continue } got = c.appendNewItem(got, sym.Name, - //fmt.Sprintf("(from %q)", candpkg), candpkg, detail, candpkg, - //convKind(sym.Kind), kind, - pkg, nil) + pkg, params) } } } @@ -369,3 +369,43 @@ func funcParams(f *ast.File, fname string) []string { } return params } + +// extract the formal parameters from the signature. 
+// func[M1 ~map[K]V, M2 ~map[K]V, K comparable, V any](dst M1, src M2) -> []{"dst M1", "src M2"} +// func[K comparable, V any](seq iter.Seq2[K, V]) map[K]V -> []{"seq iter.Seq2[K, V]"} +// func(args ...any) *Logger -> []{"args ...any"} +// func[M ~map[K]V, K comparable, V any](m M, del func(K, V) bool) -> []{"m M", "del func(K, V) bool"} +func parseSignature(sig string) []string { + var level int // nesting level of delimiters + var processing bool // are we doing the params + var last int // start of current parameter + var params []string + for i := range len(sig) { + switch sig[i] { + case '[', '{': + level++ + case ']', '}': + level-- + case '(': + level++ + if level == 1 { + processing = true + last = i + 1 + } + case ')': + level-- + if level == 0 && processing { // done + if i > last { + params = append(params, strings.TrimSpace(sig[last:i])) + } + return params + } + case ',': + if level == 1 && processing { + params = append(params, strings.TrimSpace(sig[last:i])) + last = i + 1 + } + } + } + return nil +} diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index 59f10f8dff0..d6e00055879 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -1339,6 +1339,7 @@ func _() { } // Fix for golang/go#60062: unimported completion included "golang.org/toolchain" results. +// and check that functions (from the standard library) have snippets func TestToolchainCompletions(t *testing.T) { const files = ` -- go.mod -- @@ -1385,6 +1386,16 @@ func Join() {} if strings.Contains(item.Detail, "golang.org/toolchain") { t.Errorf("Completion(...) 
returned toolchain item %#v", item) } + if strings.HasPrefix(item.Detail, "func") { + // check that there are snippets + x, ok := item.TextEdit.Value.(protocol.InsertReplaceEdit) + if !ok { + t.Errorf("item.TextEdit.Value unexpected type %T", item.TextEdit.Value) + } + if !strings.Contains(x.NewText, "${1") { + t.Errorf("expected snippet in %q", x.NewText) + } + } } } }) From 4354923be0c02b493588d8a5806b11786e77a5e9 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Tue, 27 May 2025 19:44:25 +0000 Subject: [PATCH 136/196] internal/mcp: add pagination for prompts Change-Id: Ib231427299dbfe807fb80ad8da6ba16d4a8fde03 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676615 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/server.go | 11 ++++++++++- internal/mcp/server_example_test.go | 22 +++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 85c31b20a87..fda8bf7a9b0 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -179,8 +179,17 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() + var cursor string + if params != nil { + cursor = params.Cursor + } + prompts, nextCursor, err := paginateList(s.prompts, cursor, s.opts.PageSize) + if err != nil { + return nil, err + } res := new(ListPromptsResult) - for p := range s.prompts.all() { + res.NextCursor = nextCursor + for _, p := range prompts { res.Prompts = append(res.Prompts, p.Prompt) } return res, nil diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 913a41cff20..c17dafc814b 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -72,7 +72,7 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession return clientSession, 
serverSession, server } -func TestListTool(t *testing.T) { +func TestListTools(t *testing.T) { toolA := mcp.NewTool("apple", "apple tool", SayHi) toolB := mcp.NewTool("banana", "banana tool", SayHi) toolC := mcp.NewTool("cherry", "cherry tool", SayHi) @@ -111,3 +111,23 @@ func TestListResources(t *testing.T) { t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff) } } + +func TestListPrompts(t *testing.T) { + promptA := mcp.NewPrompt("apple", "apple prompt", testPromptHandler[struct{}]) + promptB := mcp.NewPrompt("banana", "banana prompt", testPromptHandler[struct{}]) + promptC := mcp.NewPrompt("cherry", "cherry prompt", testPromptHandler[struct{}]) + prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} + wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} + ctx := context.Background() + clientSession, serverSession, server := createSessions(ctx) + defer clientSession.Close() + defer serverSession.Close() + server.AddPrompts(prompts...) + res, err := clientSession.ListPrompts(ctx, nil) + if err != nil { + t.Fatal("ListPrompts() failed:", err) + } + if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListPrompts() mismatch (-want +got):\n%s", diff) + } +} From 50c5c27ce8d97a276f5b75881ecac3be775b04a0 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 26 May 2025 13:12:35 -0400 Subject: [PATCH 137/196] internal/mcp/jsonschema: check for structure separately Redo iteration. It was too subtle. Write a separate recursive check for whether the schemas form a tree and have any invalid nils. While doing that, assign paths to schemas for better error messages, even as we check structure. This also lets us detect re-occurring schemas without a separate map. Everything else, including the iterators, can now assume valid structure and populated paths. 
Change-Id: I90ec93d79d1d95c92258ad7afc94ef068fdf123f Reviewed-on: https://go-review.googlesource.com/c/tools/+/676376 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Auto-Submit: Jonathan Amsterdam --- internal/mcp/jsonschema/resolve.go | 80 ++++++++++++++++++++----- internal/mcp/jsonschema/resolve_test.go | 73 +++++++++++++++++----- internal/mcp/jsonschema/schema.go | 47 ++++++--------- internal/mcp/jsonschema/schema_test.go | 33 ---------- internal/mcp/jsonschema/validate.go | 6 +- 5 files changed, 143 insertions(+), 96 deletions(-) diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 63a6df3a69b..a2257b8de17 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "net/url" + "reflect" "regexp" "strings" ) @@ -111,30 +112,81 @@ func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { } func (root *Schema) check() error { - if root == nil { - return errors.New("nil schema") + // Check for structural validity. Do this first and fail fast: + // bad structure will cause other code to panic. + if err := root.checkStructure(); err != nil { + return err } + var errs []error report := func(err error) { errs = append(errs, err) } - seen := map[*Schema]bool{} - for ss, path := range root.all() { - if seen[ss] { - // The schema graph rooted at s is not a tree, but it needs to + for ss := range root.all() { + ss.checkLocal(report) + } + return errors.Join(errs...) +} + +// checkStructure verifies that root and its subschemas form a tree. +// It also assigns each schema a unique path, to improve error messages. +func (root *Schema) checkStructure() error { + var check func(reflect.Value, []byte) error + check = func(v reflect.Value, path []byte) error { + // For the purpose of error messages, the root schema has path "root" + // and other schemas' paths are their JSON Pointer from the root. 
+ p := "root" + if len(path) > 0 { + p = string(path) + } + s := v.Interface().(*Schema) + if s == nil { + return fmt.Errorf("jsonschema: schema at %s is nil", p) + } + if s.path != "" { + // We've seen s before. + // The schema graph at root is not a tree, but it needs to // be because we assume a unique parent when we store a schema's base // in the Schema. A cycle would also put Schema.all into an infinite // recursion. - return fmt.Errorf("schemas rooted at %s do not form a tree (saw %s twice)", root, ss) + return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", + root, s.path, p) } - seen[ss] = true - if len(path) == 0 { - ss.path = "root" - } else { - ss.path = "/" + strings.Join(path, "/") + s.path = p + + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. + // A nil is valid: it just means the field isn't present. + if !fv.IsNil() { + if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { + return err + } + } + + case schemaSliceType: + for i := range fv.Len() { + if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { + return err + } + } + + case schemaMapType: + iter := fv.MapRange() + for iter.Next() { + key := escapeJSONPointerSegment(iter.Key().String()) + if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { + return err + } + } + } + } - ss.checkLocal(report) + return nil } - return errors.Join(errs...) + + return check(reflect.ValueOf(root), make([]byte, 0, 256)) } // checkLocal checks s for validity, independently of other schemas it may refer to. 
diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 2b60759cf7d..5621e38eb3f 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -14,18 +14,41 @@ import ( "testing" ) +func TestSchemaStructure(t *testing.T) { + check := func(s *Schema, want string) { + t.Helper() + err := s.checkStructure() + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("checkStructure returned error %q, want %q", err, want) + } + } + + dag := &Schema{Type: "number"} + dag = &Schema{Items: dag, Contains: dag} + check(dag, "do not form a tree") + + tree := &Schema{Type: "number"} + tree.Items = tree + check(tree, "do not form a tree") + + sliceNil := &Schema{PrefixItems: []*Schema{nil}} + check(sliceNil, "is nil") + + sliceMap := &Schema{Properties: map[string]*Schema{"a": nil}} + check(sliceMap, "is nil") +} + func TestCheckLocal(t *testing.T) { for _, tt := range []struct { s *Schema want string // error must be non-nil and match this regexp }{ - {nil, "nil"}, { &Schema{Pattern: "]["}, "regexp", }, { - &Schema{PatternProperties: map[string]*Schema{"*": nil}}, + &Schema{PatternProperties: map[string]*Schema{"*": {}}}, "regexp", }, } { @@ -35,26 +58,48 @@ func TestCheckLocal(t *testing.T) { continue } if !regexp.MustCompile(tt.want).MatchString(err.Error()) { - t.Errorf("%s: did not match\nerror: %s\nregexp: %s", + t.Errorf("checkLocal returned error\n%q\nwanted it to match\n%s\nregexp: %s", tt.s.json(), err, tt.want) } } } -func TestSchemaNonTree(t *testing.T) { - run := func(s *Schema, kind string) { - err := s.check() - if err == nil || !strings.Contains(err.Error(), "tree") { - t.Fatalf("did not detect %s", kind) - } +func TestPaths(t *testing.T) { + // CheckStructure should assign paths to schemas. + // This test also verifies that Schema.all visits maps in sorted order. 
+ root := &Schema{ + Type: "string", + PrefixItems: []*Schema{{Type: "int"}, {Items: &Schema{Type: "null"}}}, + Contains: &Schema{Properties: map[string]*Schema{ + "~1": {Type: "boolean"}, + "p": {}, + }}, } - s := &Schema{Type: "number"} - run(&Schema{Items: s, Contains: s}, "DAG") + type item struct { + s *Schema + p string + } + want := []item{ + {root, "root"}, + {root.Contains, "/contains"}, + {root.Contains.Properties["p"], "/contains/properties/p"}, + {root.Contains.Properties["~1"], "/contains/properties/~01"}, + {root.PrefixItems[0], "/prefixItems/0"}, + {root.PrefixItems[1], "/prefixItems/1"}, + {root.PrefixItems[1].Items, "/prefixItems/1/items"}, + } + if err := root.checkStructure(); err != nil { + t.Fatal(err) + } - root := &Schema{Items: s} - s.Items = root - run(root, "cycle") + var got []item + for s := range root.all() { + got = append(got, item{s, s.path}) + } + if !slices.Equal(got, want) { + t.Errorf("\ngot %v\nwant %v", got, want) + } } func TestResolveURIs(t *testing.T) { diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 245a4ac685f..1ced58787d7 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -17,7 +17,6 @@ import ( "reflect" "regexp" "slices" - "strconv" ) // A Schema is a JSON schema object. @@ -361,27 +360,12 @@ func Ptr[T any](x T) *T { return &x } // every applies f preorder to every schema under s including s. // The second argument to f is the path to the schema appended to the argument path. // It stops when f returns false. -func (s *Schema) every(f func(*Schema, []string) bool, path []string) bool { - return s == nil || - f(s, path) && s.everyChild(func(s *Schema, p []string) bool { return s.every(f, p) }, path) +func (s *Schema) every(f func(*Schema) bool) bool { + return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) } // everyChild reports whether f is true for every immediate child schema of s. 
-// The second argument to f is the path to the schema appended to the argument path. -// -// It does not call f on nil-valued fields holding individual schemas, like Contains, -// because a nil value indicates that the field is absent. -// It does call f on nils when they occur in slices and maps, so those invalid values -// can be detected when the schema is validated. -func (s *Schema) everyChild(f func(*Schema, []string) bool, path []string) bool { - if s == nil { - return false - } - var ( - schemaType = reflect.TypeFor[*Schema]() - schemaSliceType = reflect.TypeFor[[]*Schema]() - schemaMapType = reflect.TypeFor[map[string]*Schema]() - ) +func (s *Schema) everyChild(f func(*Schema) bool) bool { v := reflect.ValueOf(s) for _, info := range schemaFieldInfos { fv := v.Elem().FieldByIndex(info.sf.Index) @@ -389,25 +373,23 @@ func (s *Schema) everyChild(f func(*Schema, []string) bool, path []string) bool case schemaType: // A field that contains an individual schema. A nil is valid: it just means the field isn't present. c := fv.Interface().(*Schema) - if c != nil && !f(c, append(path, info.jsonName)) { + if c != nil && !f(c) { return false } case schemaSliceType: slice := fv.Interface().([]*Schema) - // A field that contains a slice of schemas. Yield nils so we can check for their presence. - for i, c := range slice { - if !f(c, append(path, info.jsonName, strconv.Itoa(i))) { + for _, c := range slice { + if !f(c) { return false } } case schemaMapType: - // A field that is a map of schemas. Ditto about nils. // Sort keys for determinism. m := fv.Interface().(map[string]*Schema) for _, k := range slices.Sorted(maps.Keys(m)) { - if !f(m[k], append(path, info.jsonName, escapeJSONPointerSegment(k))) { + if !f(m[k]) { return false } } @@ -417,16 +399,21 @@ func (s *Schema) everyChild(f func(*Schema, []string) bool, path []string) bool } // all wraps every in an iterator. 
-func (s *Schema) all() iter.Seq2[*Schema, []string] { - return func(yield func(*Schema, []string) bool) { s.every(yield, nil) } +func (s *Schema) all() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.every(yield) } } // children wraps everyChild in an iterator. -func (s *Schema) children() iter.Seq2[*Schema, []string] { - var pathBuffer [4]string - return func(yield func(*Schema, []string) bool) { s.everyChild(yield, pathBuffer[:0]) } +func (s *Schema) children() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.everyChild(yield) } } +var ( + schemaType = reflect.TypeFor[*Schema]() + schemaSliceType = reflect.TypeFor[[]*Schema]() + schemaMapType = reflect.TypeFor[map[string]*Schema]() +) + type structFieldInfo struct { sf reflect.StructField jsonName string diff --git a/internal/mcp/jsonschema/schema_test.go b/internal/mcp/jsonschema/schema_test.go index 921656cf30f..cc331417286 100644 --- a/internal/mcp/jsonschema/schema_test.go +++ b/internal/mcp/jsonschema/schema_test.go @@ -9,8 +9,6 @@ import ( "fmt" "math" "regexp" - "slices" - "strings" "testing" ) @@ -111,37 +109,6 @@ func TestUnmarshalErrors(t *testing.T) { } } -func TestEvery(t *testing.T) { - // Schema.every should visit all descendants of a schema, not just the immediate ones. 
- s := &Schema{ - Type: "string", - PrefixItems: []*Schema{{Type: "int"}, {Items: &Schema{Type: "null"}}}, - Contains: &Schema{Properties: map[string]*Schema{ - "~1": {Type: "boolean"}, - "p": {}, - }}, - } - - type item struct { - s *Schema - p string - } - want := []item{ - {s, ""}, - {s.Contains, "contains"}, - {s.Contains.Properties["p"], "contains/properties/p"}, - {s.Contains.Properties["~1"], "contains/properties/~01"}, - {s.PrefixItems[0], "prefixItems/0"}, - {s.PrefixItems[1], "prefixItems/1"}, - {s.PrefixItems[1].Items, "prefixItems/1/items"}, - } - var got []item - s.every(func(s *Schema, p []string) bool { got = append(got, item{s, strings.Join(p, "/")}); return true }, nil) - if !slices.Equal(got, want) { - t.Errorf("\n got %v\nwant %v", got, want) - } -} - func mustUnmarshal(t *testing.T, data []byte, ptr any) { t.Helper() if err := json.Unmarshal(data, ptr); err != nil { diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 510dc73a80e..991c5f6701e 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -32,8 +32,7 @@ func (rs *Resolved) Validate(instance any) error { // state is the state of single call to ResolvedSchema.Validate. type state struct { - rs *Resolved - depth int + rs *Resolved // stack holds the schemas from recursive calls to validate. // These are the "dynamic scopes" used to resolve dynamic references. // https://json-schema.org/draft/2020-12/json-schema-core#scopes @@ -48,9 +47,6 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an defer func() { st.stack = st.stack[:len(st.stack)-1] // pop }() - if depth := len(st.stack); depth >= 100 { - return fmt.Errorf("max recursion depth of %d reached", depth) - } // We checked for nil schemas in [Schema.Resolve]. 
assert(schema != nil, "nil schema") From 3eaf5e21c82c3effa43998f4d3a26cc6dfb3881b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 26 May 2025 13:31:17 -0400 Subject: [PATCH 138/196] internal/mcp/jsonschema: validate defaults Optionally validate a schema's defaults when resolving it. We simply walk the tree of schemas, validating each locally. This can't handle dynamic references, which require maintaining the stack from the root. It seems unlikely that a default would require one, however. Since there may be other options to Resolve, use an options struct to future-proof it. Also, fix a bug: you can't resolve a schema twice.
+ BaseURI string + // Loader loads schemas that are referred to by a $ref but are not under the + // root schema (remote references). + // If nil, resolving a remote reference will return an error. + Loader Loader + // ValidateDefaults determines whether to validate values of "default" keywords + // against their schemas. + // The [JSON Schema specification] does not require this, but it is + // recommended if defaults will be used. + // + // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations + ValidateDefaults bool +} + // Resolve resolves all references within the schema and performs other tasks that // prepare the schema for validation. -// -// baseURI can be empty, or an absolute URI (one that starts with a scheme). -// It is resolved (in the URI sense; see [url.ResolveReference]) with root's $id property. -// If the resulting URI is not absolute, then the schema cannot not contain relative URI references. -// -// loader loads schemas that are referred to by a $ref but not under root (a remote reference). -// If nil, remote references will return an error. -func (root *Schema) Resolve(baseURI string, loader Loader) (*Resolved, error) { - // There are four steps involved in preparing a schema to validate. +// If opts is nil, the default values are used. +func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { + // There are up to five steps required to prepare a schema to validate. // 1. Load: read the schema from somewhere and unmarshal it. // This schema (root) may have been loaded or created in memory, but other schemas that // come into the picture in step 4 will be loaded by the given loader. @@ -49,37 +65,48 @@ func (root *Schema) Resolve(baseURI string, loader Loader) (*Resolved, error) { // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and // resolve (in the URI sense) all identifiers and anchors with their bases. 
This step results // in a map from URIs to schemas within root. - // These three steps are idempotent. They may occur a several times on a schema, if - // it is loaded from several places. // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. + // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. + if root.path != "" { + return nil, fmt.Errorf("jsonschema: Resolve: %s already resolved", root) + } + r := &resolver{loaded: map[string]*Resolved{}} + if opts != nil { + r.opts = *opts + } var base *url.URL - if baseURI == "" { + if r.opts.BaseURI == "" { base = &url.URL{} // so we can call ResolveReference on it } else { var err error - base, err = url.Parse(baseURI) + base, err = url.Parse(r.opts.BaseURI) if err != nil { return nil, fmt.Errorf("parsing base URI: %w", err) } } - if loader == nil { - loader = func(uri *url.URL) (*Schema, error) { + if r.opts.Loader == nil { + r.opts.Loader = func(uri *url.URL) (*Schema, error) { return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") } } - r := &resolver{ - loader: loader, - loaded: map[string]*Resolved{}, - } - return r.resolve(root, base) + resolved, err := r.resolve(root, base) + if err != nil { + return nil, err + } + if r.opts.ValidateDefaults { + if err := resolved.validateDefaults(); err != nil { + return nil, err + } + } // TODO: before we return, throw away anything we don't need for validation. + return resolved, nil } // A resolver holds the state for resolution. type resolver struct { - loader Loader + opts ResolveOptions // A cache of loaded and partly resolved schemas. (They may not have had their // refs resolved.) The cache ensures that the loader will never be called more // than once with the same URI, and that reference cycles are handled properly. 
@@ -406,7 +433,7 @@ func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, d referencedSchema = lrs.root } else { // Try to load the schema. - ls, err := r.loader(fraglessRefURI) + ls, err := r.opts.Loader(fraglessRefURI) if err != nil { return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) } diff --git a/internal/mcp/jsonschema/resolve_test.go b/internal/mcp/jsonschema/resolve_test.go index 5621e38eb3f..b501147e0db 100644 --- a/internal/mcp/jsonschema/resolve_test.go +++ b/internal/mcp/jsonschema/resolve_test.go @@ -52,7 +52,7 @@ func TestCheckLocal(t *testing.T) { "regexp", }, } { - _, err := tt.s.Resolve("", nil) + _, err := tt.s.Resolve(nil) if err == nil { t.Errorf("%s: unexpectedly passed", tt.s.json()) continue @@ -192,7 +192,7 @@ func TestRefCycle(t *testing.T) { return s, nil } - rs, err := schemas["root"].Resolve("", loader) + rs, err := schemas["root"].Resolve(&ResolveOptions{Loader: loader}) if err != nil { t.Fatal(err) } diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 1ced58787d7..34ec5be73b1 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -182,7 +182,9 @@ type anchorInfo struct { // String returns a short description of the schema. func (s *Schema) String() string { if s.uri != nil { - return s.uri.String() + if u := s.uri.String(); u != "" { + return u + } } if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { return fmt.Sprintf("%q, anchor %s", s.base.uri.String(), a) diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 991c5f6701e..ca26e891c35 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -30,6 +30,31 @@ func (rs *Resolved) Validate(instance any) error { return st.validate(reflect.ValueOf(instance), st.rs.root, nil) } +// validateDefaults walks the schema tree. If it finds a default, it validates it +// against the schema containing it. 
+// +// TODO(jba): account for dynamic refs. This algorithm simple-mindedly +// treats each schema with a default as its own root. +func (rs *Resolved) validateDefaults() error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + for s := range rs.root.all() { + // We checked for nil schemas in [Schema.Resolve]. + assert(s != nil, "nil schema") + if s.DynamicRef != "" { + return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", s) + } + if s.Default != nil { + if err := st.validate(reflect.ValueOf(*s.Default), s, nil); err != nil { + return err + } + } + } + return nil +} + // state is the state of single call to ResolvedSchema.Validate. type state struct { rs *Resolved diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 89cc2475b22..0b710f1d958 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -53,7 +53,7 @@ func TestValidate(t *testing.T) { } for _, g := range groups { t.Run(g.Description, func(t *testing.T) { - rs, err := g.Schema.Resolve("", loadRemote) + rs, err := g.Schema.Resolve(&ResolveOptions{Loader: loadRemote}) if err != nil { t.Fatal(err) } @@ -82,7 +82,7 @@ func TestValidateErrors(t *testing.T) { schema := &Schema{ PrefixItems: []*Schema{{Contains: &Schema{Type: "integer"}}}, } - rs, err := schema.Resolve("", nil) + rs, err := schema.Resolve(nil) if err != nil { t.Fatal(err) } @@ -93,6 +93,34 @@ func TestValidateErrors(t *testing.T) { } } +func TestValidateDefaults(t *testing.T) { + anyptr := func(x any) *any { return &x } + + s := &Schema{ + Properties: map[string]*Schema{ + "a": {Type: "integer", Default: anyptr(3)}, + "b": {Type: "string", Default: anyptr("s")}, + }, + Default: anyptr(map[string]any{"a": 1, "b": "two"}), + } + if _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}); err != nil { + 
t.Fatal(err) + } + + s = &Schema{ + Properties: map[string]*Schema{ + "a": {Type: "integer", Default: anyptr(3)}, + "b": {Type: "string", Default: anyptr("s")}, + }, + Default: anyptr(map[string]any{"a": 1, "b": 2}), + } + _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}) + want := `has type "integer", want "string"` + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("Resolve returned error %q, want %q", err, want) + } +} + func TestStructInstance(t *testing.T) { instance := struct { I int @@ -174,7 +202,7 @@ func TestStructInstance(t *testing.T) { false, }, } { - res, err := tt.s.Resolve("", nil) + res, err := tt.s.Resolve(nil) if err != nil { t.Fatal(err) } From 9233e122509d237673c6a375a16bbe09a253e201 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 28 May 2025 10:52:00 -0400 Subject: [PATCH 139/196] gopls: update to staticcheck@master Dominik recently added a number of optimizations based on the typeindex analyzer, and they are now ready for use; see https://github.com/dominikh/go-tools/issues/1652#issuecomment-2905907150. 
Change-Id: I0cfd6df01487d9b242d138bedf100ebbb2d0f9f9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676795 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/doc/analyzers.md | 4 ++-- gopls/go.mod | 4 ++-- gopls/go.sum | 4 ++-- gopls/internal/doc/api.json | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index 2a974eaa496..331456a6c82 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -1875,7 +1875,7 @@ In a type switch like the following type T struct{} func (T) Read(b []byte) (int, error) { return 0, nil } - var v interface{} = T{} + var v any = T{} switch v.(type) { case io.Reader: @@ -1893,7 +1893,7 @@ Another example: func (T) Read(b []byte) (int, error) { return 0, nil } func (T) Close() error { return nil } - var v interface{} = T{} + var v any = T{} switch v.(type) { case io.Reader: diff --git a/gopls/go.mod b/gopls/go.mod index 9868579f20d..80da71f797e 100644 --- a/gopls/go.mod +++ b/gopls/go.mod @@ -11,10 +11,10 @@ require ( golang.org/x/sys v0.33.0 golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 golang.org/x/text v0.25.0 - golang.org/x/tools v0.30.0 + golang.org/x/tools v0.33.1-0.20250521210010-423c5afcceff golang.org/x/vuln v1.1.4 gopkg.in/yaml.v3 v3.0.1 - honnef.co/go/tools v0.6.1 + honnef.co/go/tools v0.7.0-0.dev.0.20250523013057-bbc2f4dd71ea mvdan.cc/gofumpt v0.7.0 mvdan.cc/xurls/v2 v2.6.0 ) diff --git a/gopls/go.sum b/gopls/go.sum index 143edbc8909..d6d9d39c7cd 100644 --- a/gopls/go.sum +++ b/gopls/go.sum @@ -59,8 +59,8 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.6.1 
h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= -honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= +honnef.co/go/tools v0.7.0-0.dev.0.20250523013057-bbc2f4dd71ea h1:fj8r9irJSpolAGUdZBxJIRY3lLc4jH2Dt4lwnWyWwpw= +honnef.co/go/tools v0.7.0-0.dev.0.20250523013057-bbc2f4dd71ea/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= mvdan.cc/gofumpt v0.7.0/go.mod h1:txVFJy/Sc/mvaycET54pV8SW8gWxTlUuGHVEcncmNUo= mvdan.cc/xurls/v2 v2.6.0 h1:3NTZpeTxYVWNSokW3MKeyVkz/j7uYXYiMtXRUfmjbgI= diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index fa73b711868..664561ab5f1 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1006,7 +1006,7 @@ }, { "Name": "\"SA4020\"", - "Doc": "Unreachable case clause in a type switch\n\nIn a type switch like the following\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n\n var v interface{} = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case T:\n // unreachable\n }\n\nthe second case clause can never be reached because T implements\nio.Reader and case clauses are evaluated in source order.\n\nAnother example:\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n func (T) Close() error { return nil }\n\n var v interface{} = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case io.ReadCloser:\n // unreachable\n }\n\nEven though T has a Close method and thus implements io.ReadCloser,\nio.Reader will always match first. The method set of io.Reader is a\nsubset of io.ReadCloser. Thus it is impossible to match the second\ncase without matching the first case.\n\n\nStructurally equivalent interfaces\n\nA special case of the previous example are structurally identical\ninterfaces. 
Given these declarations\n\n type T error\n type V error\n\n func doSomething() error {\n err, ok := doAnotherThing()\n if ok {\n return T(err)\n }\n\n return U(err)\n }\n\nthe following type switch will have an unreachable case clause:\n\n switch doSomething().(type) {\n case T:\n // ...\n case V:\n // unreachable\n }\n\nT will always match before V because they are structurally equivalent\nand therefore doSomething()'s return value implements both.\n\nAvailable since\n 2019.2\n", + "Doc": "Unreachable case clause in a type switch\n\nIn a type switch like the following\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n\n var v any = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case T:\n // unreachable\n }\n\nthe second case clause can never be reached because T implements\nio.Reader and case clauses are evaluated in source order.\n\nAnother example:\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n func (T) Close() error { return nil }\n\n var v any = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case io.ReadCloser:\n // unreachable\n }\n\nEven though T has a Close method and thus implements io.ReadCloser,\nio.Reader will always match first. The method set of io.Reader is a\nsubset of io.ReadCloser. Thus it is impossible to match the second\ncase without matching the first case.\n\n\nStructurally equivalent interfaces\n\nA special case of the previous example are structurally identical\ninterfaces. 
Given these declarations\n\n type T error\n type V error\n\n func doSomething() error {\n err, ok := doAnotherThing()\n if ok {\n return T(err)\n }\n\n return U(err)\n }\n\nthe following type switch will have an unreachable case clause:\n\n switch doSomething().(type) {\n case T:\n // ...\n case V:\n // unreachable\n }\n\nT will always match before V because they are structurally equivalent\nand therefore doSomething()'s return value implements both.\n\nAvailable since\n 2019.2\n", "Default": "true", "Status": "" }, @@ -2732,7 +2732,7 @@ }, { "Name": "SA4020", - "Doc": "Unreachable case clause in a type switch\n\nIn a type switch like the following\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n\n var v interface{} = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case T:\n // unreachable\n }\n\nthe second case clause can never be reached because T implements\nio.Reader and case clauses are evaluated in source order.\n\nAnother example:\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n func (T) Close() error { return nil }\n\n var v interface{} = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case io.ReadCloser:\n // unreachable\n }\n\nEven though T has a Close method and thus implements io.ReadCloser,\nio.Reader will always match first. The method set of io.Reader is a\nsubset of io.ReadCloser. Thus it is impossible to match the second\ncase without matching the first case.\n\n\nStructurally equivalent interfaces\n\nA special case of the previous example are structurally identical\ninterfaces. 
Given these declarations\n\n type T error\n type V error\n\n func doSomething() error {\n err, ok := doAnotherThing()\n if ok {\n return T(err)\n }\n\n return U(err)\n }\n\nthe following type switch will have an unreachable case clause:\n\n switch doSomething().(type) {\n case T:\n // ...\n case V:\n // unreachable\n }\n\nT will always match before V because they are structurally equivalent\nand therefore doSomething()'s return value implements both.\n\nAvailable since\n 2019.2\n", + "Doc": "Unreachable case clause in a type switch\n\nIn a type switch like the following\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n\n var v any = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case T:\n // unreachable\n }\n\nthe second case clause can never be reached because T implements\nio.Reader and case clauses are evaluated in source order.\n\nAnother example:\n\n type T struct{}\n func (T) Read(b []byte) (int, error) { return 0, nil }\n func (T) Close() error { return nil }\n\n var v any = T{}\n\n switch v.(type) {\n case io.Reader:\n // ...\n case io.ReadCloser:\n // unreachable\n }\n\nEven though T has a Close method and thus implements io.ReadCloser,\nio.Reader will always match first. The method set of io.Reader is a\nsubset of io.ReadCloser. Thus it is impossible to match the second\ncase without matching the first case.\n\n\nStructurally equivalent interfaces\n\nA special case of the previous example are structurally identical\ninterfaces. 
Given these declarations\n\n type T error\n type V error\n\n func doSomething() error {\n err, ok := doAnotherThing()\n if ok {\n return T(err)\n }\n\n return U(err)\n }\n\nthe following type switch will have an unreachable case clause:\n\n switch doSomething().(type) {\n case T:\n // ...\n case V:\n // unreachable\n }\n\nT will always match before V because they are structurally equivalent\nand therefore doSomething()'s return value implements both.\n\nAvailable since\n 2019.2\n", "URL": "https://staticcheck.dev/docs/checks/#SA4020", "Default": true }, From 359ea3adb94a8d93dc6ffe853b91b1e2d31a1d9c Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 28 May 2025 04:02:00 +0000 Subject: [PATCH 140/196] internal/mcp: add JSON conformance tests, and fix some bugs Add the beginnings of some testdata-driven conformance tests that check JSON-level compatibility of SDK messages, and fix some bugs in our JSON message handling: - "params" can be missing - "_meta" should not be null - lists of tools/prompts/resources should not be null In order to implement this, we need to be able to know when the server is done processing messages. In general, this is a hard problem; for gopls we do this with a complicated system of progress reporting and accounting. However, the new synctest package makes this trivial, as we can call synctest.Wait to detect when the server is idle. Therefore, this test requires 1.25 (it uses synctest.Test, added in 1.25, and so can't run at 1.24 with GOEXPERIMENT=synctest). 
Change-Id: I97fa28b56868340ece266a58e1eb3c3bf8cfe0e9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676538 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Auto-Submit: Robert Findley --- internal/jsonrpc2_v2/messages.go | 13 + internal/mcp/conformance_test.go | 280 ++++++++++++++++++ internal/mcp/mcp_test.go | 111 ++++--- internal/mcp/server.go | 7 + internal/mcp/shared.go | 8 +- .../testdata/conformance/server/prompts.txtar | 74 +++++ .../testdata/conformance/server/tools.txtar | 93 ++++++ 7 files changed, 535 insertions(+), 51 deletions(-) create mode 100644 internal/mcp/conformance_test.go create mode 100644 internal/mcp/testdata/conformance/server/prompts.txtar create mode 100644 internal/mcp/testdata/conformance/server/tools.txtar diff --git a/internal/jsonrpc2_v2/messages.go b/internal/jsonrpc2_v2/messages.go index 3b2ebc7afeb..0aa321d92b6 100644 --- a/internal/jsonrpc2_v2/messages.go +++ b/internal/jsonrpc2_v2/messages.go @@ -146,6 +146,19 @@ func EncodeMessage(msg Message) ([]byte, error) { return data, nil } +// EncodeIndent is like EncodeMessage, but honors indents. +// TODO(rfindley): refactor so that this concern is handled independently. +// Perhaps we should pass in a json.Encoder? +func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.MarshalIndent(&wire, prefix, indent) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + func DecodeMessage(data []byte) (Message, error) { msg := wireCombined{} if err := json.Unmarshal(data, &msg); err != nil { diff --git a/internal/mcp/conformance_test.go b/internal/mcp/conformance_test.go new file mode 100644 index 00000000000..dfd6b264269 --- /dev/null +++ b/internal/mcp/conformance_test.go @@ -0,0 +1,280 @@ +// Copyright 2025 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.25 + +package mcp + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "testing/synctest" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" + "golang.org/x/tools/txtar" +) + +var update = flag.Bool("update", false, "if set, update conformance test data") + +// A conformance test checks JSON-level conformance of a test server or client. +// This allows us to confirm that we can handle the input or output of other +// SDKs, even if they behave differently at the JSON level (for example, have +// different behavior with respect to optional fields). +// +// The client and server fields hold an encoded sequence of JSON-RPC messages. +// +// For server tests, the client messages are a sequence of messages to be sent +// from the (synthetic) client and the server messages are the expected +// messages to be received from the real server. +// +// For client tests, it's the other way around: server messages are synthetic, +// and client messages are expected from the real client. +// +// Conformance tests are loaded from txtar-encoded testdata files. Run the test +// with -update to have the test runner update the expected output, which may +// be client or server depending on the perspective of the test. +type conformanceTest struct { + name string // test name + path string // path to test file + archive *txtar.Archive // raw archive, for updating + tools, prompts, resources []string // named features to include + client []jsonrpc2.Message // client messages + server []jsonrpc2.Message // server messages +} + +// TODO(rfindley): add client conformance tests. 
+ +func TestServerConformance(t *testing.T) { + var tests []*conformanceTest + dir := filepath.Join("testdata", "conformance", "server") + if err := filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error { + if err != nil { + return err + } + if strings.HasSuffix(path, ".txtar") { + test, err := loadConformanceTest(dir, path) + if err != nil { + return fmt.Errorf("%s: %v", path, err) + } + tests = append(tests, test) + } + return nil + }); err != nil { + t.Fatal(err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // We use synctest here because in general, there is no way to know when the + // server is done processing any notifications. As long as our server doesn't + // do background work, synctest provides an easy way for us to detect when the + // server is done processing. + // + // By comparison, gopls has a complicated framework based on progress + // reporting and careful accounting to detect when all 'expected' work + // on the server is complete. + synctest.Test(t, func(t *testing.T) { + runServerTest(t, test) + }) + }) + } +} + +// runServerTest runs the server conformance test. +// It must be executed in a synctest bubble. +func runServerTest(t *testing.T, test *conformanceTest) { + ctx := t.Context() + // Construct the server based on features listed in the test. + s := NewServer("testServer", "v1.0.0", nil) + add(tools, s.AddTools, test.tools...) + add(prompts, s.AddPrompts, test.prompts...) + add(resources, s.AddResources, test.resources...) + + // Connect the server, and connect the client stream, + // but don't connect an actual client. + cTransport, sTransport := NewInMemoryTransports() + ss, err := s.Connect(ctx, sTransport) + if err != nil { + t.Fatal(err) + } + cStream, err := cTransport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + + // Collect server messages asynchronously. 
+ var wg sync.WaitGroup + var ( + serverMessages []jsonrpc2.Message + serverErr error // abnormal failure of the server stream + ) + wg.Add(1) + go func() { + defer wg.Done() + for { + msg, _, err := cStream.Read(ctx) + if err != nil { + // TODO(rfindley): we don't document (or want to document) that the in + // memory transports use a net.Pipe. How can users detect this failure? + // Should we promote it to EOF? + if !errors.Is(err, io.ErrClosedPipe) { + serverErr = err + } + break + } + serverMessages = append(serverMessages, msg) + } + }() + + // Write client messages to the stream. + for _, msg := range test.client { + if _, err := cStream.Write(ctx, msg); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Before closing the stream, wait for all messages to be processed. + synctest.Wait() + if err := cStream.Close(); err != nil { + t.Fatalf("Stream.Close failed: %v", err) + } + + ss.Wait() + wg.Wait() + if serverErr != nil { + t.Fatalf("reading server messages failed: %v", serverErr) + } + + // Handle server output. If -update is set, write the 'server' file. + // Otherwise, compare with expected. 
+ if *update { + arch := &txtar.Archive{ + Comment: test.archive.Comment, + } + var buf bytes.Buffer + for _, msg := range serverMessages { + data, err := jsonrpc2.EncodeIndent(msg, "", "\t") + if err != nil { + t.Fatalf("jsonrpc2.EncodeIndent failed: %v", err) + } + buf.Write(data) + buf.WriteByte('\n') + } + serverFile := txtar.File{Name: "server", Data: buf.Bytes()} + seenServer := false // replace or append the 'server' file + for _, f := range test.archive.Files { + if f.Name == "server" { + seenServer = true + arch.Files = append(arch.Files, serverFile) + } else { + arch.Files = append(arch.Files, f) + } + } + if !seenServer { + arch.Files = append(arch.Files, serverFile) + } + if err := os.WriteFile(test.path, txtar.Format(arch), 0666); err != nil { + t.Fatalf("os.WriteFile(%q) failed: %v", test.path, err) + } + } else { + // jsonrpc2.Messages are not comparable, so we instead compare lines of JSON. + transform := cmpopts.AcyclicTransformer("toJSON", func(msg jsonrpc2.Message) []string { + encoded, err := jsonrpc2.EncodeIndent(msg, "", "\t") + if err != nil { + t.Fatal(err) + } + return strings.Split(string(encoded), "\n") + }) + if diff := cmp.Diff(test.server, serverMessages, transform); diff != "" { + t.Errorf("Mismatching server messages (-want +got):\n%s", diff) + } + } +} + +// loadConformanceTest loads one conformance test from the given path contained +// in the root dir. +func loadConformanceTest(dir, path string) (*conformanceTest, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + test := &conformanceTest{ + name: strings.TrimPrefix(path, dir+string(filepath.Separator)), + path: path, + archive: txtar.Parse(content), + } + if len(test.archive.Files) == 0 { + return nil, fmt.Errorf("txtar archive %q has no '-- filename --' sections", path) + } + + // decodeMessages loads JSON-RPC messages from the archive file. 
+ decodeMessages := func(data []byte) ([]jsonrpc2.Message, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + var res []jsonrpc2.Message + for dec.More() { + var raw json.RawMessage + if err := dec.Decode(&raw); err != nil { + return nil, err + } + m, err := jsonrpc2.DecodeMessage(raw) + if err != nil { + return nil, err + } + res = append(res, m) + } + return res, nil + } + // loadFeatures loads lists of named features from the archive file. + loadFeatures := func(data []byte) []string { + var feats []string + for line := range strings.SplitSeq(string(data), "\n") { + if f := strings.TrimSpace(line); f != "" { + feats = append(feats, f) + } + } + return feats + } + + seen := make(map[string]bool) // catch accidentally duplicate files + for _, f := range test.archive.Files { + if seen[f.Name] { + return nil, fmt.Errorf("duplicate file name %q", f.Name) + } + seen[f.Name] = true + switch f.Name { + case "tools": + test.tools = loadFeatures(f.Data) + case "prompts": + test.prompts = loadFeatures(f.Data) + case "resource": + test.resources = loadFeatures(f.Data) + case "client": + test.client, err = decodeMessages(f.Data) + if err != nil { + return nil, fmt.Errorf("txtar archive %q contains bad -- client -- section: %v", path, err) + } + case "server": + test.server, err = decodeMessages(f.Data) + if err != nil { + return nil, fmt.Errorf("txtar archive %q contains bad -- server -- section: %v", path, err) + } + default: + return nil, fmt.Errorf("txtar archive %q contains unexpected file %q", path, f.Name) + } + } + + return test, nil +} diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 8a8ddfd1290..ca62608d1ac 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -58,34 +58,10 @@ func TestEndToEnd(t *testing.T) { InitializedHandler: func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { 
notificationChans["roots"] <- 0 }, } - s := NewServer("testServer", "v1.0.0", sopts) - - // The 'greet' tool says hi. - s.AddTools(NewTool("greet", "say hi", sayHi)) - - // The 'fail' tool returns this error. - failure := errors.New("mcp failure") - s.AddTools( - NewTool("fail", "just fail", func(context.Context, *ServerSession, struct{}) ([]*Content, error) { - return nil, failure - }), - ) - - s.AddPrompts( - NewPrompt("code_review", "do a code review", - func(_ context.Context, _ *ServerSession, params struct{ Code string }, _ *GetPromptParams) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, - }, - }, nil - }), - NewPrompt("fail", "", func(_ context.Context, _ *ServerSession, args struct{}, _ *GetPromptParams) (*GetPromptResult, error) { - return nil, failure - }), - ) + add(tools, s.AddTools, "greet", "fail") + add(prompts, s.AddPrompts, "code_review", "fail") + add(resources, s.AddResources, "info.txt", "fail.txt") // Connect the server. 
ss, err := s.Connect(ctx, st) @@ -167,8 +143,8 @@ func TestEndToEnd(t *testing.T) { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } - if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), failure.Error()) { - t.Errorf("fail returned unexpected error: got %v, want containing %v", err, failure) + if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), errTestFailure.Error()) { + t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) } s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) @@ -227,7 +203,7 @@ func TestEndToEnd(t *testing.T) { } wantFail := &CallToolResult{ IsError: true, - Content: []*Content{{Type: "text", Text: failure.Error()}}, + Content: []*Content{{Type: "text", Text: errTestFailure.Error()}}, } if diff := cmp.Diff(wantFail, gotFail); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) @@ -243,27 +219,12 @@ func TestEndToEnd(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("TODO: fix for Windows") } - resource1 := &Resource{ - Name: "public", - MIMEType: "text/plain", - URI: "file:///info.txt", - } - resource2 := &Resource{ - Name: "public", // names are not unique IDs - MIMEType: "text/plain", - URI: "file:///nonexistent.txt", - } - - readHandler := s.FileResourceHandler("testdata/files") - s.AddResources( - &ServerResource{resource1, readHandler}, - &ServerResource{resource2, readHandler}) - + wantResources := []*Resource{resource2, resource1} lrres, err := cs.ListResources(ctx, nil) if err != nil { t.Fatal(err) } - if diff := cmp.Diff([]*Resource{resource1, resource2}, lrres.Resources); diff != "" { + if diff := cmp.Diff(wantResources, lrres.Resources); diff != "" { t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) } @@ -272,7 +233,7 @@ func TestEndToEnd(t *testing.T) { mimeType string // "": not found; "text/plain": resource; 
"text/template": template }{ {"file:///info.txt", "text/plain"}, - {"file:///nonexistent.txt", ""}, + {"file:///fail.txt", ""}, // TODO(jba): add resource template cases when we implement them } { rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) @@ -416,6 +377,60 @@ func TestEndToEnd(t *testing.T) { } } +// Registry of values to be referenced in tests. +var ( + errTestFailure = errors.New("mcp failure") + + tools = map[string]*ServerTool{ + "greet": NewTool("greet", "say hi", sayHi), + "fail": NewTool("fail", "just fail", func(context.Context, *ServerSession, struct{}) ([]*Content, error) { + return nil, errTestFailure + }), + } + + prompts = map[string]*ServerPrompt{ + "code_review": NewPrompt("code_review", "do a code review", + func(_ context.Context, _ *ServerSession, params struct{ Code string }, _ *GetPromptParams) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: NewTextContent("Please review the following code: " + params.Code)}, + }, + }, nil + }), + "fail": NewPrompt("fail", "", func(_ context.Context, _ *ServerSession, args struct{}, _ *GetPromptParams) (*GetPromptResult, error) { + return nil, errTestFailure + }), + } + + resource1 = &Resource{ + Name: "public", + MIMEType: "text/plain", + URI: "file:///info.txt", + } + resource2 = &Resource{ + Name: "public", // names are not unique IDs + MIMEType: "text/plain", + URI: "file:///fail.txt", + } + readHandler = fileResourceHandler("testdata/files") + resources = map[string]*ServerResource{ + "info.txt": {resource1, readHandler}, + "fail.txt": {resource2, readHandler}, + } +) + +// Add calls the given function to add the named features. +func add[T any](m map[string]T, add func(...T), names ...string) { + for _, name := range names { + feat, ok := m[name] + if !ok { + panic("missing feature " + name) + } + add(feat) + } +} + // errorCode returns the code associated with err. 
// If err is nil, it returns 0. // If there is no code, it returns -1. diff --git a/internal/mcp/server.go b/internal/mcp/server.go index fda8bf7a9b0..76bebb25a0a 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -189,6 +189,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr } res := new(ListPromptsResult) res.NextCursor = nextCursor + res.Prompts = []*Prompt{} // avoid JSON null for _, p := range prompts { res.Prompts = append(res.Prompts, p.Prompt) } @@ -219,6 +220,7 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool } res := new(ListToolsResult) res.NextCursor = nextCursor + res.Tools = []*Tool{} // avoid JSON null for _, t := range tools { res.Tools = append(res.Tools, t.Tool) } @@ -248,6 +250,7 @@ func (s *Server) listResources(_ context.Context, _ *ServerSession, params *List } res := new(ListResourcesResult) res.NextCursor = nextCursor + res.Resources = []*Resource{} // avoid JSON null for _, r := range resources { res.Resources = append(res.Resources, r.Resource) } @@ -296,6 +299,10 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re // are always caught. Go 1.24 and above also protects against symlink-based attacks, // where symlinks under dir lead out of the tree. func (s *Server) FileResourceHandler(dir string) ResourceHandler { + return fileResourceHandler(dir) +} + +func fileResourceHandler(dir string) ResourceHandler { // Convert dir to an absolute path. 
dirFilepath, err := filepath.Abs(dir) if err != nil { diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index ce9c93b34ff..6f6f3d0e3eb 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -105,8 +105,10 @@ func newMethodInfo[S ClientSession | ServerSession, P Params, R Result](d typedM return methodInfo[S]{ unmarshalParams: func(m json.RawMessage) (Params, error) { var p P - if err := json.Unmarshal(m, &p); err != nil { - return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + if m != nil { + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } } return p, nil }, @@ -200,7 +202,7 @@ func (m Meta) MarshalJSON() ([]byte, error) { // If ProgressToken is nil, accept Data["progressToken"]. We can't call marshalStructWithMap // in that case because it will complain about duplicate fields. (We'd have to // make it much smarter to avoid that; not worth it.) - if m.ProgressToken == nil { + if m.ProgressToken == nil && len(m.Data) > 0 { return json.Marshal(m.Data) } return marshalStructWithMap((*metaSansMethods)(&m), "Data") diff --git a/internal/mcp/testdata/conformance/server/prompts.txtar b/internal/mcp/testdata/conformance/server/prompts.txtar new file mode 100644 index 00000000000..0ef0cdc22e9 --- /dev/null +++ b/internal/mcp/testdata/conformance/server/prompts.txtar @@ -0,0 +1,74 @@ +Check behavior of a server with just prompts. 
+ +Fixed bugs: +- empty tools lists should not be returned as 'null' + +-- prompts -- +code_review + +-- client -- +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "_meta": {}, + "capabilities": { + "completions": {}, + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "_meta": {}, + "tools": [] + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "result": { + "_meta": {}, + "prompts": [ + { + "arguments": [ + { + "name": "Code", + "required": true + } + ], + "description": "do a code review", + "name": "code_review" + } + ] + } +} diff --git a/internal/mcp/testdata/conformance/server/tools.txtar b/internal/mcp/testdata/conformance/server/tools.txtar new file mode 100644 index 00000000000..039e08d71ff --- /dev/null +++ b/internal/mcp/testdata/conformance/server/tools.txtar @@ -0,0 +1,93 @@ +Check behavior of a server with just tools. 
+ +Fixed bugs: +- "tools/list" can have missing params +- "_meta" should not be nil +- empty resource or prompts should not be returned as 'null' + +-- tools -- +greet + +-- client -- +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ "jsonrpc": "2.0", "id": 3, "method": "resources/list" } +{ "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "_meta": {}, + "capabilities": { + "completions": {}, + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "_meta": {}, + "tools": [ + { + "description": "say hi", + "inputSchema": { + "type": "object", + "required": [ + "Name" + ], + "properties": { + "Name": { + "type": "string" + } + }, + "additionalProperties": { + "not": {} + } + }, + "name": "greet" + } + ] + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "_meta": {}, + "resources": [] + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "result": { + "_meta": {}, + "prompts": [] + } +} From 93f6460e4cb24d20cbf1d03f4fa71a6a7c2014d6 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 28 May 2025 16:50:51 +0000 Subject: [PATCH 141/196] internal/mcp: add iterator method for tools Change-Id: Ibe799caf6592b2763868bdb30f474bc224aa8b00 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676915 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 29 +++++++++++++++ internal/mcp/server_example_test.go | 57 +++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git 
a/internal/mcp/client.go b/internal/mcp/client.go index a42e74678fa..1d8252c0e2f 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "fmt" + "iter" "slices" "sync" @@ -314,3 +315,31 @@ func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, para } return nil, nil } + +// Tools provides an iterator for all tools available on the server, +// automatically fetching pages and managing cursors. +// The `params` argument can set the initial cursor. +func (c *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[Tool, error] { + currentParams := &ListToolsParams{} + if params != nil { + *currentParams = *params + } + return func(yield func(Tool, error) bool) { + for { + res, err := c.ListTools(ctx, currentParams) + if err != nil { + yield(Tool{}, err) + return + } + for _, t := range res.Tools { + if !yield(*t, nil) { + return + } + } + if res.NextCursor == "" { + return + } + currentParams.Cursor = res.NextCursor + } + } +} diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index c17dafc814b..231a329d004 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -57,8 +57,8 @@ func ExampleServer() { } // createSessions creates and connects an in-memory client and server session for testing purposes. 
-func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { - server := mcp.NewServer("server", "v0.0.1", nil) +func createSessions(ctx context.Context, opts *mcp.ServerOptions) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { + server := mcp.NewServer("server", "v0.0.1", opts) client := mcp.NewClient("client", "v0.0.1", nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() serverSession, err := server.Connect(ctx, serverTransport) @@ -77,19 +77,44 @@ func TestListTools(t *testing.T) { toolB := mcp.NewTool("banana", "banana tool", SayHi) toolC := mcp.NewTool("cherry", "cherry tool", SayHi) tools := []*mcp.ServerTool{toolA, toolB, toolC} - wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} + wantListTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} + wantIteratorTools := []mcp.Tool{*toolA.Tool, *toolB.Tool, *toolC.Tool} ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx) - defer clientSession.Close() - defer serverSession.Close() - server.AddTools(tools...) - res, err := clientSession.ListTools(ctx, nil) - if err != nil { - t.Fatal("ListTools() failed:", err) - } - if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff) - } + t.Run("ListTools", func(t *testing.T) { + clientSession, serverSession, server := createSessions(ctx, nil) + defer clientSession.Close() + defer serverSession.Close() + server.AddTools(tools...) 
+ res, err := clientSession.ListTools(ctx, nil) + if err != nil { + t.Fatal("ListTools() failed:", err) + } + if diff := cmp.Diff(wantListTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff) + } + }) + t.Run("ToolsIterator", func(t *testing.T) { + for pageSize := range len(tools) + 1 { + testName := fmt.Sprintf("PageSize=%v", pageSize) + t.Run(testName, func(t *testing.T) { + clientSession, serverSession, server := createSessions(ctx, &mcp.ServerOptions{PageSize: pageSize}) + defer clientSession.Close() + defer serverSession.Close() + server.AddTools(tools...) + var gotTools []mcp.Tool + seq := clientSession.Tools(ctx, nil) + for tool, err := range seq { + if err != nil { + t.Fatalf("Tools(%s) failed: %v", testName, err) + } + gotTools = append(gotTools, tool) + } + if diff := cmp.Diff(wantIteratorTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("Tools(%s) mismatch (-want +got):\n%s", testName, diff) + } + }) + } + }) } func TestListResources(t *testing.T) { @@ -99,7 +124,7 @@ func TestListResources(t *testing.T) { resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} wantResource := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx) + clientSession, serverSession, server := createSessions(ctx, nil) defer clientSession.Close() defer serverSession.Close() server.AddResources(resources...) @@ -119,7 +144,7 @@ func TestListPrompts(t *testing.T) { prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx) + clientSession, serverSession, server := createSessions(ctx, nil) defer clientSession.Close() defer serverSession.Close() server.AddPrompts(prompts...) 
From de7968da89c1fe050e2810b7da259a86b0577866 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 28 May 2025 12:51:16 -0400 Subject: [PATCH 142/196] internal/gofix: add -gofix.allow_binding_decl flag In some case, fixes from the inliner may involve literalizations: f(args) -> func(){...}() or the insertion of a binding declaration: f(args) -> var params = args; ... The gofix inliner always discards the first kind, since they are stylistically wanting; this change causes it to conditionally discard the second kind based on the -allow_binding_decl flag, default false. This default is based on feedback from batch inlining in Google's Go corpus. Also, use correct h2 marker ## in doc.go. + test, doc Change-Id: I6855a78b02c60d6a3f65d253b6017e1c6f6bd064 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676935 Reviewed-by: Jonathan Amsterdam Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/doc/analyzers.md | 90 +++++++++++++++++++ gopls/internal/doc/api.json | 4 +- internal/gofix/doc.go | 22 ++++- internal/gofix/gofix.go | 15 ++++ internal/gofix/gofix_test.go | 16 ++++ .../gofix/testdata/src/binding_false/a.go | 14 +++ .../testdata/src/binding_false/a.go.golden | 14 +++ internal/gofix/testdata/src/binding_true/a.go | 14 +++ .../testdata/src/binding_true/a.go.golden | 15 ++++ 9 files changed, 200 insertions(+), 4 deletions(-) create mode 100644 internal/gofix/testdata/src/binding_false/a.go create mode 100644 internal/gofix/testdata/src/binding_false/a.go.golden create mode 100644 internal/gofix/testdata/src/binding_true/a.go create mode 100644 internal/gofix/testdata/src/binding_true/a.go.golden diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index 331456a6c82..999e34d9a98 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3488,6 +3488,96 @@ Package documentation: [framepointer](https://pkg.go.dev/golang.org/x/tools/go/a The gofix analyzer inlines functions and constants that are marked for inlining. 
+## Functions + +Given a function that is marked for inlining, like this one: + + //go:fix inline + func Square(x int) int { return Pow(x, 2) } + +this analyzer will recommend that calls to the function elsewhere, in the same +or other packages, should be inlined. + +Inlining can be used to move off of a deprecated function: + + // Deprecated: prefer Pow(x, 2). + //go:fix inline + func Square(x int) int { return Pow(x, 2) } + +It can also be used to move off of an obsolete package, +as when the import path has changed or a higher major version is available: + + package pkg + + import pkg2 "pkg/v2" + + //go:fix inline + func F() { pkg2.F(nil) } + +Replacing a call pkg.F() by pkg2.F(nil) can have no effect on the program, +so this mechanism provides a low-risk way to update large numbers of calls. +We recommend, where possible, expressing the old API in terms of the new one +to enable automatic migration. + +The inliner takes care to avoid behavior changes, even subtle ones, +such as changes to the order in which argument expressions are +evaluated. When it cannot safely eliminate all parameter variables, +it may introduce a "binding declaration" of the form + + var params = args + +to evaluate argument expressions in the correct order and bind them to +parameter variables. Since the resulting code transformation may be +stylistically suboptimal, such inlinings may be disabled by specifying +the -gofix.allow_binding_decl=false flag to the analyzer driver. + +(In cases where it is not safe to "reduce" a call—that is, to replace +a call f(x) by the body of function f, suitably substituted—the +inliner machinery is capable of replacing f by a function literal, +func(){...}(). However, the gofix analyzer discards all such +"literalizations" unconditionally, again on grounds of style.) 
+ +## Constants + +Given a constant that is marked for inlining, like this one: + + //go:fix inline + const Ptr = Pointer + +this analyzer will recommend that uses of Ptr should be replaced with Pointer. + +As with functions, inlining can be used to replace deprecated constants and +constants in obsolete packages. + +A constant definition can be marked for inlining only if it refers to another +named constant. + +The "//go:fix inline" comment must appear before a single const declaration on its own, +as above; before a const declaration that is part of a group, as in this case: + + const ( + C = 1 + //go:fix inline + Ptr = Pointer + ) + +or before a group, applying to every constant in the group: + + //go:fix inline + const ( + Ptr = Pointer + Val = Value + ) + +The proposal https://go.dev/issue/32816 introduces the "//go:fix" directives. + +You can use this (officially unsupported) command to apply gofix fixes en masse: + + $ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./... + +(Do not use "go get -tool" to add gopls as a dependency of your +module; gopls commands must be built from their release branch.) + Default: on. 
Package documentation: [gofix](https://pkg.go.dev/golang.org/x/tools/internal/gofix) diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 664561ab5f1..1bab1dc08f2 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1450,7 +1450,7 @@ }, { "Name": "\"gofix\"", - "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.", + "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete package,\nas when the import path has changed or a higher major version is available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. 
Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", "Default": "true", "Status": "" }, @@ -3176,7 +3176,7 @@ }, { "Name": "gofix", - "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.", + "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix 
analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete package,\nas when the import path has changed or a higher major version is available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). 
However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", "URL": "https://pkg.go.dev/golang.org/x/tools/internal/gofix", "Default": true }, diff --git a/internal/gofix/doc.go b/internal/gofix/doc.go index 7b7576cb828..b3e693d7dce 100644 --- a/internal/gofix/doc.go +++ b/internal/gofix/doc.go @@ -13,7 +13,7 @@ gofix: apply fixes based on go:fix comment directives The gofix analyzer inlines functions and constants that are marked for inlining. -# Functions +## Functions Given a function that is marked for inlining, like this one: @@ -44,7 +44,25 @@ so this mechanism provides a low-risk way to update large numbers of calls. We recommend, where possible, expressing the old API in terms of the new one to enable automatic migration. 
-# Constants +The inliner takes care to avoid behavior changes, even subtle ones, +such as changes to the order in which argument expressions are +evaluated. When it cannot safely eliminate all parameter variables, +it may introduce a "binding declaration" of the form + + var params = args + +to evaluate argument expressions in the correct order and bind them to +parameter variables. Since the resulting code transformation may be +stylistically suboptimal, such inlinings may be disabled by specifying +the -gofix.allow_binding_decl=false flag to the analyzer driver. + +(In cases where it is not safe to "reduce" a call—that is, to replace +a call f(x) by the body of function f, suitably substituted—the +inliner machinery is capable of replacing f by a function literal, +func(){...}(). However, the gofix analyzer discards all such +"literalizations" unconditionally, again on grounds of style.) + +## Constants Given a constant that is marked for inlining, like this one: diff --git a/internal/gofix/gofix.go b/internal/gofix/gofix.go index 51b23c65849..9de3b2eaa6c 100644 --- a/internal/gofix/gofix.go +++ b/internal/gofix/gofix.go @@ -43,6 +43,13 @@ var Analyzer = &analysis.Analyzer{ Requires: []*analysis.Analyzer{inspect.Analyzer}, } +var allowBindingDecl bool + +func init() { + Analyzer.Flags.BoolVar(&allowBindingDecl, "allow_binding_decl", false, + "permit inlinings that require a 'var params = args' declaration") +} + // analyzer holds the state for this analysis. type analyzer struct { pass *analysis.Pass @@ -189,6 +196,14 @@ func (a *analyzer) inlineCall(call *ast.CallExpr, cur inspector.Cursor) { // has no indication of what the problem is.) return } + if res.BindingDecl && !allowBindingDecl { + // When applying fix en masse, users are similarly + // unenthusiastic about inlinings that cannot + // entirely eliminate the parameters and + // insert a 'var params = args' declaration. + // The flag allows them to decline such fixes. 
+ return + } got := res.Content // Suggest the "fix". diff --git a/internal/gofix/gofix_test.go b/internal/gofix/gofix_test.go index 9194d893577..21ffd7078d0 100644 --- a/internal/gofix/gofix_test.go +++ b/internal/gofix/gofix_test.go @@ -5,6 +5,7 @@ package gofix import ( + "fmt" "go/ast" "go/importer" "go/parser" @@ -25,6 +26,21 @@ func TestAnalyzer(t *testing.T) { analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), Analyzer, "a", "b") } +func TestAllowBindingDeclFlag(t *testing.T) { + saved := allowBindingDecl + defer func() { allowBindingDecl = saved }() + + run := func(allow bool) { + name := fmt.Sprintf("binding_%v", allow) + t.Run(name, func(t *testing.T) { + allowBindingDecl = allow + analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), Analyzer, name) + }) + } + run(true) // testdata/src/binding_true + run(false) // testdata/src/binding_false +} + func TestTypesWithNames(t *testing.T) { // Test setup inspired by internal/analysisinternal/addimport_test.go. testenv.NeedsDefaultImporter(t) diff --git a/internal/gofix/testdata/src/binding_false/a.go b/internal/gofix/testdata/src/binding_false/a.go new file mode 100644 index 00000000000..1cab5a83275 --- /dev/null +++ b/internal/gofix/testdata/src/binding_false/a.go @@ -0,0 +1,14 @@ +package a + +//go:fix inline +func f(x, y int) int { // want f:`goFixInline a.f` + return y + x +} + +func g() { + f(1, 2) // want `Call of a.f should be inlined` + + f(h(1), h(2)) +} + +func h(int) int diff --git a/internal/gofix/testdata/src/binding_false/a.go.golden b/internal/gofix/testdata/src/binding_false/a.go.golden new file mode 100644 index 00000000000..51a740be06c --- /dev/null +++ b/internal/gofix/testdata/src/binding_false/a.go.golden @@ -0,0 +1,14 @@ +package a + +//go:fix inline +func f(x, y int) int { // want f:`goFixInline a.f` + return y + x +} + +func g() { + _ = 2 + 1 // want `Call of a.f should be inlined` + + f(h(1), h(2)) +} + +func h(int) int diff --git 
a/internal/gofix/testdata/src/binding_true/a.go b/internal/gofix/testdata/src/binding_true/a.go new file mode 100644 index 00000000000..eab7883d550 --- /dev/null +++ b/internal/gofix/testdata/src/binding_true/a.go @@ -0,0 +1,14 @@ +package a + +//go:fix inline +func f(x, y int) int { // want f:`goFixInline a.f` + return y + x +} + +func g() { + f(1, 2) // want `Call of a.f should be inlined` + + f(h(1), h(2)) // want `Call of a.f should be inlined` +} + +func h(int) int diff --git a/internal/gofix/testdata/src/binding_true/a.go.golden b/internal/gofix/testdata/src/binding_true/a.go.golden new file mode 100644 index 00000000000..97afac40241 --- /dev/null +++ b/internal/gofix/testdata/src/binding_true/a.go.golden @@ -0,0 +1,15 @@ +package a + +//go:fix inline +func f(x, y int) int { // want f:`goFixInline a.f` + return y + x +} + +func g() { + _ = 2 + 1 // want `Call of a.f should be inlined` + + var x int = h(1) + _ = h(2) + x // want `Call of a.f should be inlined` +} + +func h(int) int From d12ca1c0b0fd0b04e3c59646ade63051e86b589d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 28 May 2025 14:41:55 -0400 Subject: [PATCH 143/196] internal/mcp/jsonschema: remove shared structure from schema inference Schemas must form a tree. This is important for Schema.Resolve, and also for modifying parts of the schema before resolution, as with mcp SchemaOptions. Remove the `seen` argument from inference, creating all schemas afresh. 
Change-Id: Ief8295e8470f3f575ef51cc2ba697901250bbe6f Reviewed-on: https://go-review.googlesource.com/c/tools/+/676937 Auto-Submit: Jonathan Amsterdam Reviewed-by: Hongxiang Jiang LUCI-TryBot-Result: Go LUCI --- internal/mcp/jsonschema/infer.go | 14 +++---- internal/mcp/jsonschema/infer_test.go | 59 ++++++++++++++++++--------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/internal/mcp/jsonschema/infer.go b/internal/mcp/jsonschema/infer.go index b5605fd56a1..4ce270e5159 100644 --- a/internal/mcp/jsonschema/infer.go +++ b/internal/mcp/jsonschema/infer.go @@ -44,21 +44,17 @@ func For[T any]() (*Schema, error) { // // TODO(rfindley): we could perhaps just skip these incompatible fields. func ForType(t reflect.Type) (*Schema, error) { - return typeSchema(t, make(map[reflect.Type]*Schema)) + return typeSchema(t) } -func typeSchema(t reflect.Type, seen map[reflect.Type]*Schema) (*Schema, error) { +func typeSchema(t reflect.Type) (*Schema, error) { if t.Kind() == reflect.Pointer { t = t.Elem() } - if s := seen[t]; s != nil { - return s, nil - } var ( s = new(Schema) err error ) - seen[t] = s switch t.Kind() { case reflect.Bool: @@ -80,14 +76,14 @@ func typeSchema(t reflect.Type, seen map[reflect.Type]*Schema) (*Schema, error) return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } s.Type = "object" - s.AdditionalProperties, err = typeSchema(t.Elem(), seen) + s.AdditionalProperties, err = typeSchema(t.Elem()) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = typeSchema(t.Elem(), seen) + s.Items, err = typeSchema(t.Elem()) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } @@ -113,7 +109,7 @@ func typeSchema(t reflect.Type, seen map[reflect.Type]*Schema) (*Schema, error) if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[name], err = typeSchema(field.Type, seen) + s.Properties[name], err 
= typeSchema(field.Type) if err != nil { return nil, err } diff --git a/internal/mcp/jsonschema/infer_test.go b/internal/mcp/jsonschema/infer_test.go index fe289815a2a..150824cb947 100644 --- a/internal/mcp/jsonschema/infer_test.go +++ b/internal/mcp/jsonschema/infer_test.go @@ -41,31 +41,52 @@ func TestForType(t *testing.T) { Type: "object", AdditionalProperties: &schema{}, }}, - {"struct", forType[struct { - F int `json:"f"` - G []float64 - P *bool - Skip string `json:"-"` - NoSkip string `json:",omitempty"` - unexported float64 - unexported2 int `json:"No"` - }](), &schema{ - Type: "object", - Properties: map[string]*schema{ - "f": {Type: "integer"}, - "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Type: "boolean"}, - "NoSkip": {Type: "string"}, + { + "struct", + forType[struct { + F int `json:"f"` + G []float64 + P *bool + Skip string `json:"-"` + NoSkip string `json:",omitempty"` + unexported float64 + unexported2 int `json:"No"` + }](), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "f": {Type: "integer"}, + "G": {Type: "array", Items: &schema{Type: "number"}}, + "P": {Type: "boolean"}, + "NoSkip": {Type: "string"}, + }, + Required: []string{"f", "G", "P"}, + AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, }, - Required: []string{"f", "G", "P"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, - }}, + }, + { + "no sharing", + forType[struct{ X, Y int }](), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "X": {Type: "integer"}, + "Y": {Type: "integer"}, + }, + Required: []string{"X", "Y"}, + AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("ForType mismatch (-want +got):\n%s", diff) + t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + } + // These schemas 
should all resolve. + if _, err := test.got.Resolve(nil); err != nil { + t.Fatalf("Resolving: %v", err) } }) } From 53be3d459b4b3e2541005a4b2b49f8751e02f270 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 27 May 2025 07:08:57 -0400 Subject: [PATCH 144/196] internal/mcp/jsonschema: apply schema defaults to an instance This CL handles only the limited case of defaults on properties. This is just a start: I need to think more about defaults on other schema keywords, and I need to set defaults recursively. As part of this work, I realized that the Schema.Default field is more usable when represented as a json.RawMessage instead of a *any. Change-Id: Iedcedb405b52607a312dbd895fa61e7f7f94c6bc Reviewed-on: https://go-review.googlesource.com/c/tools/+/676938 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Auto-Submit: Jonathan Amsterdam --- internal/mcp/jsonschema/resolve.go | 2 +- internal/mcp/jsonschema/schema.go | 20 ++-- internal/mcp/jsonschema/schema_test.go | 6 +- internal/mcp/jsonschema/validate.go | 116 ++++++++++++++++++++++- internal/mcp/jsonschema/validate_test.go | 68 +++++++++++-- 5 files changed, 187 insertions(+), 25 deletions(-) diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 4725c0d4436..1754135e6aa 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -223,7 +223,7 @@ func (root *Schema) checkStructure() error { func (s *Schema) checkLocal(report func(error)) { addf := func(format string, args ...any) { msg := fmt.Sprintf(format, args...) 
- report(fmt.Errorf("jsonschema.Schema: %s: %s", s.path, msg)) + report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) } if s == nil { diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 34ec5be73b1..225beaec958 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -57,13 +57,13 @@ type Schema struct { Vocabulary map[string]bool `json:"$vocabulary,omitempty"` // metadata - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Default *any `json:"default,omitempty"` - Deprecated bool `json:"deprecated,omitempty"` - ReadOnly bool `json:"readOnly,omitempty"` - WriteOnly bool `json:"writeOnly,omitempty"` - Examples []any `json:"examples,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Default json.RawMessage `json:"default,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + WriteOnly bool `json:"writeOnly,omitempty"` + Examples []any `json:"examples,omitempty"` // validation // Use Type for a single type, or Types for multiple types; never both. @@ -254,7 +254,6 @@ func (s *Schema) UnmarshalJSON(data []byte) error { ms := struct { Type json.RawMessage `json:"type,omitempty"` Const json.RawMessage `json:"const,omitempty"` - Default json.RawMessage `json:"default,omitempty"` MinLength *integer `json:"minLength,omitempty"` MaxLength *integer `json:"maxLength,omitempty"` MinItems *integer `json:"minItems,omitempty"` @@ -298,14 +297,11 @@ func (s *Schema) UnmarshalJSON(data []byte) error { return json.Unmarshal(raw, p) } - // Setting Const or Default to a pointer to null will marshal properly, but won't + // Setting Const to a pointer to null will marshal properly, but won't // unmarshal: the *any is set to nil, not a pointer to nil. 
if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { return err } - if err := unmarshalAnyPtr(&s.Default, ms.Default); err != nil { - return err - } set := func(dst **int, src *integer) { if src != nil { diff --git a/internal/mcp/jsonschema/schema_test.go b/internal/mcp/jsonschema/schema_test.go index cc331417286..a5b7baf28ed 100644 --- a/internal/mcp/jsonschema/schema_test.go +++ b/internal/mcp/jsonschema/schema_test.go @@ -24,17 +24,17 @@ func TestGoRoundTrip(t *testing.T) { {Const: Ptr(any(nil))}, {Const: Ptr(any([]int{}))}, {Const: Ptr(any(map[string]any{}))}, - {Default: Ptr(any(nil))}, + {Default: mustMarshal(1)}, + {Default: mustMarshal(nil)}, } { data, err := json.Marshal(s) if err != nil { t.Fatal(err) } - t.Logf("marshal: %s", data) var got *Schema mustUnmarshal(t, data, &got) if !Equal(got, s) { - t.Errorf("got %+v, want %+v", got, s) + t.Errorf("got %s, want %s", got.json(), s.json()) if got.Const != nil && s.Const != nil { t.Logf("Consts: got %#v (%[1]T), want %#v (%[2]T)", *got.Const, *s.Const) } diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index ca26e891c35..f0ddcf6ce58 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -5,6 +5,7 @@ package jsonschema import ( + "encoding/json" "fmt" "hash/maphash" "iter" @@ -47,7 +48,11 @@ func (rs *Resolved) validateDefaults() error { return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", s) } if s.Default != nil { - if err := st.validate(reflect.ValueOf(*s.Default), s, nil); err != nil { + var d any + if err := json.Unmarshal(s.Default, &d); err != nil { + fmt.Errorf("unmarshaling default value of schema %s: %w", s, err) + } + if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { return err } } @@ -68,6 +73,7 @@ type state struct { func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { defer wrapf(&err, "validating %s", schema) + // 
Maintain a stack for dynamic schema resolution. st.stack = append(st.stack, schema) // push defer func() { st.stack = st.stack[:len(st.stack)-1] // pop @@ -536,6 +542,113 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an return nil } +// resolveDynamicRef returns the schema referred to by the argument schema's +// $dynamicRef value. +// It returns an error if the dynamic reference has no referent. +// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. +func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { + if schema.DynamicRef == "" { + return nil, nil + } + // The ref behaves lexically or dynamically, but not both. + assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), + "DynamicRef not statically resolved properly") + if r := schema.resolvedDynamicRef; r != nil { + // Same as $ref. + return r, nil + } + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. + for _, s := range st.stack { + info, ok := s.base.anchors[schema.dynamicRefAnchor] + if ok && info.dynamic { + return info.schema, nil + } + } + return nil, fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) +} + +// ApplyDefaults modifies an instance by applying the schema's defaults to it. If +// a schema or sub-schema has a default, then a corresponding zero instance value +// is set to the default. 
+// +// The JSON Schema specification does not describe how defaults should be interpreted. +// This method honors defaults only on properties, and only those that are not required. +// If the instance is a map and the property is missing, the property is added to +// the map with the default. +// If the instance is a struct, the field corresponding to the property exists, and +// its value is zero, the field is set to the default. +// ApplyDefaults can panic if a default cannot be assigned to a field. +// +// The argument must be a pointer to the instance. +// (In case we decide that top-level defaults are meaningful.) +// +// It is recommended to first call Resolve with a ValidateDefaults option of true, +// then call this method, and lastly call Validate. +// +// TODO(jba): consider what defaults on top-level or array instances might mean. +// TODO(jba): follow $ref and $dynamicRef +// TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. +func (rs *Resolved) ApplyDefaults(instancep any) error { + st := &state{rs: rs} + return st.applyDefaults(reflect.ValueOf(instancep), rs.root) +} + +// Leave this as a potentially recursive helper function, because we'll surely want +// to apply defaults on sub-schemas someday. +func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { + defer wrapf(&err, "applyDefaults: schema %s, instance %v", schema, instancep) + + instance := instancep.Elem() + if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + } + for prop, subschema := range schema.Properties { + // Ignore defaults on required properties. (A required property shouldn't have a default.) 
+ if schema.isRequired[prop] { + continue + } + val := property(instance, prop) + switch instance.Kind() { + case reflect.Map: + // If there is a default for this property, and the map key is missing, + // set the map value to the default. + if subschema.Default != nil && !val.IsValid() { + // Create an lvalue, since map values aren't addressable. + lvalue := reflect.New(instance.Type().Elem()) + if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { + return err + } + instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) + } + case reflect.Struct: + // If there is a default for this property, and the field exists but is zero, + // set the field to the default. + if subschema.Default != nil && val.IsValid() && val.IsZero() { + if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { + return err + } + } + default: + panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", + prop, instance, instance.Kind())) + } + } + } + return nil +} + // property returns the value of the property of v with the given name, or the invalid // reflect.Value if there is none. // If v is a map, the property is the value of the map whose key is name. @@ -548,6 +661,7 @@ func property(v reflect.Value, name string) reflect.Value { return v.MapIndex(reflect.ValueOf(name)) case reflect.Struct: props := structPropertiesOf(v.Type()) + // Ignore nonexistent properties. 
if index, ok := props[name]; ok { return v.FieldByIndex(index) } diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 0b710f1d958..76eb7c27c10 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -94,14 +94,12 @@ func TestValidateErrors(t *testing.T) { } func TestValidateDefaults(t *testing.T) { - anyptr := func(x any) *any { return &x } - s := &Schema{ Properties: map[string]*Schema{ - "a": {Type: "integer", Default: anyptr(3)}, - "b": {Type: "string", Default: anyptr("s")}, + "a": {Type: "integer", Default: mustMarshal(1)}, + "b": {Type: "string", Default: mustMarshal("s")}, }, - Default: anyptr(map[string]any{"a": 1, "b": "two"}), + Default: mustMarshal(map[string]any{"a": 1, "b": "two"}), } if _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}); err != nil { t.Fatal(err) @@ -109,10 +107,10 @@ func TestValidateDefaults(t *testing.T) { s = &Schema{ Properties: map[string]*Schema{ - "a": {Type: "integer", Default: anyptr(3)}, - "b": {Type: "string", Default: anyptr("s")}, + "a": {Type: "integer", Default: mustMarshal(3)}, + "b": {Type: "string", Default: mustMarshal("s")}, }, - Default: anyptr(map[string]any{"a": 1, "b": 2}), + Default: mustMarshal(map[string]any{"a": 1, "b": 2}), } _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}) want := `has type "integer", want "string"` @@ -121,6 +119,52 @@ func TestValidateDefaults(t *testing.T) { } } +func TestApplyDefaults(t *testing.T) { + schema := &Schema{ + Properties: map[string]*Schema{ + "A": {Default: mustMarshal(1)}, + "B": {Default: mustMarshal(2)}, + "C": {Default: mustMarshal(3)}, + }, + Required: []string{"C"}, + } + rs, err := schema.Resolve(&ResolveOptions{ValidateDefaults: true}) + if err != nil { + t.Fatal(err) + } + + type S struct{ A, B, C int } + for _, tt := range []struct { + instancep any // pointer to instance value + want any // desired value (not a pointer) + }{ + { + 
&map[string]any{"B": 0}, + map[string]any{ + "A": float64(1), // filled from default + "B": 0, // untouched: it was already there + // "C" not added: it is required (Validate will catch that) + }, + }, + { + &S{B: 1}, + S{ + A: 1, // filled from default + B: 1, // untouched: non-zero + C: 0, // untouched: required + }, + }, + } { + if err := rs.ApplyDefaults(tt.instancep); err != nil { + t.Fatal(err) + } + got := reflect.ValueOf(tt.instancep).Elem().Interface() // dereference the pointer + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("\ngot %#v\nwant %#v", got, tt.want) + } + } +} + func TestStructInstance(t *testing.T) { instance := struct { I int @@ -235,6 +279,14 @@ func TestJSONName(t *testing.T) { } } +func mustMarshal(x any) json.RawMessage { + data, err := json.Marshal(x) + if err != nil { + panic(err) + } + return json.RawMessage(data) +} + // loadRemote loads a remote reference used in the test suite. func loadRemote(uri *url.URL) (*Schema, error) { // Anything with localhost:1234 refers to the remotes directory in the test suite repo. From c6e0ebc6b78aff0d677da8b7c6649c6e2a929627 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 28 May 2025 16:56:55 -0400 Subject: [PATCH 145/196] internal/mcp: run conformance tests on 1.24 with a GOEXPERIMENT setting Support the conformance tests on 1.24, if the user has set GOEXPERIMENT=synctest The API is different, but other than that it works fine. 
Change-Id: I08ffbe094b39a73cd1fc26e7dcf6f22158b2df1d Reviewed-on: https://go-review.googlesource.com/c/tools/+/677015 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/conformance_test.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/internal/mcp/conformance_test.go b/internal/mcp/conformance_test.go index dfd6b264269..d7841d16456 100644 --- a/internal/mcp/conformance_test.go +++ b/internal/mcp/conformance_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.25 +//go:build go1.24 && goexperiment.synctest package mcp @@ -24,6 +24,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" + "golang.org/x/tools/internal/testenv" "golang.org/x/tools/txtar" ) @@ -58,6 +59,7 @@ type conformanceTest struct { // TODO(rfindley): add client conformance tests. func TestServerConformance(t *testing.T) { + testenv.NeedsGoExperiment(t, "synctest") var tests []*conformanceTest dir := filepath.Join("testdata", "conformance", "server") if err := filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error { @@ -86,9 +88,12 @@ func TestServerConformance(t *testing.T) { // By comparison, gopls has a complicated framework based on progress // reporting and careful accounting to detect when all 'expected' work // on the server is complete. 
- synctest.Test(t, func(t *testing.T) { - runServerTest(t, test) - }) + synctest.Run(func() { runServerTest(t, test) }) + + // TODO: in 1.25, use the following instead: + // synctest.Test(t, func(t *testing.T) { + // runServerTest(t, test) + // }) }) } } @@ -186,7 +191,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { if !seenServer { arch.Files = append(arch.Files, serverFile) } - if err := os.WriteFile(test.path, txtar.Format(arch), 0666); err != nil { + if err := os.WriteFile(test.path, txtar.Format(arch), 0o666); err != nil { t.Fatalf("os.WriteFile(%q) failed: %v", test.path, err) } } else { From 7610d95ea8b939b24f1ec41817fa3a4a421c94d1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 28 May 2025 18:24:47 -0400 Subject: [PATCH 146/196] go/analysis/passes/printf: refine diagnostic locations This change takes advantage of the new fmtstr package's precise position information to report each diagnostic associated with a conversion (e.g. %v) in a format string at the precise location of the conversion. (Previously it would report the entire call, which often spanned several lines.) The analysistest framework doesn't let us test it directly, but there is some coverage via gopls' test suite.
Change-Id: I47b698207b15ef668bff0bb521178cf9779333fb Reviewed-on: https://go-review.googlesource.com/c/tools/+/677016 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- go/analysis/passes/printf/printf.go | 54 ++++++++++++------- go/analysis/passes/tests/tests.go | 17 ++---- go/analysis/unitchecker/separate_test.go | 2 +- .../diagnostics/diagnostics_test.go | 4 +- .../marker/testdata/diagnostics/analyzers.txt | 2 +- internal/analysisinternal/analysis.go | 11 ++++ 6 files changed, 52 insertions(+), 38 deletions(-) diff --git a/go/analysis/passes/printf/printf.go b/go/analysis/passes/printf/printf.go index 07d4fcf0a80..159a95ae7d7 100644 --- a/go/analysis/passes/printf/printf.go +++ b/go/analysis/passes/printf/printf.go @@ -22,6 +22,7 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/analysisinternal" + "golang.org/x/tools/internal/astutil" "golang.org/x/tools/internal/fmtstr" "golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/versions" @@ -540,7 +541,7 @@ func checkPrintf(pass *analysis.Pass, fileVersion string, kind Kind, call *ast.C firstArg := idx + 1 // Arguments are immediately after format string. if !strings.Contains(format, "%") { if len(call.Args) > firstArg { - pass.Reportf(call.Lparen, "%s call has arguments but no formatting directives", name) + pass.ReportRangef(call.Args[firstArg], "%s call has arguments but no formatting directives", name) } return } @@ -552,7 +553,7 @@ func checkPrintf(pass *analysis.Pass, fileVersion string, kind Kind, call *ast.C if err != nil { // All error messages are in predicate form ("call has a problem") // so that they may be affixed into a subject ("log.Printf "). 
- pass.ReportRangef(call.Args[idx], "%s %s", name, err) + pass.ReportRangef(formatArg, "%s %s", name, err) return } @@ -560,20 +561,21 @@ func checkPrintf(pass *analysis.Pass, fileVersion string, kind Kind, call *ast.C maxArgIndex := firstArg - 1 anyIndex := false // Check formats against args. - for _, operation := range operations { - if operation.Prec.Index != -1 || - operation.Width.Index != -1 || - operation.Verb.Index != -1 { + for _, op := range operations { + if op.Prec.Index != -1 || + op.Width.Index != -1 || + op.Verb.Index != -1 { anyIndex = true } - if !okPrintfArg(pass, call, &maxArgIndex, firstArg, name, operation) { + rng := opRange(formatArg, op) + if !okPrintfArg(pass, call, rng, &maxArgIndex, firstArg, name, op) { // One error per format is enough. return } - if operation.Verb.Verb == 'w' { + if op.Verb.Verb == 'w' { switch kind { case KindNone, KindPrint, KindPrintf: - pass.Reportf(call.Pos(), "%s does not support error-wrapping directive %%w", name) + pass.ReportRangef(rng, "%s does not support error-wrapping directive %%w", name) return } } @@ -594,6 +596,18 @@ func checkPrintf(pass *analysis.Pass, fileVersion string, kind Kind, call *ast.C } } +// opRange returns the source range for the specified printf operation, +// such as the position of the %v substring of "...%v...". +func opRange(formatArg ast.Expr, op *fmtstr.Operation) analysis.Range { + if lit, ok := formatArg.(*ast.BasicLit); ok { + start, end, err := astutil.RangeInStringLiteral(lit, op.Range.Start, op.Range.End) + if err == nil { + return analysisinternal.Range(start, end) // position of "%v" + } + } + return formatArg // entire format string +} + // printfArgType encodes the types of expressions a printf verb accepts. It is a bitmask. type printfArgType int @@ -657,7 +671,7 @@ var printVerbs = []printVerb{ // okPrintfArg compares the operation to the arguments actually present, // reporting any discrepancies it can discern, maxArgIndex was the index of the highest used index. 
// If the final argument is ellipsissed, there's little it can do for that. -func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firstArg int, name string, operation *fmtstr.Operation) (ok bool) { +func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, rng analysis.Range, maxArgIndex *int, firstArg int, name string, operation *fmtstr.Operation) (ok bool) { verb := operation.Verb.Verb var v printVerb found := false @@ -680,7 +694,7 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs if !formatter { if !found { - pass.ReportRangef(call, "%s format %s has unknown verb %c", name, operation.Text, verb) + pass.ReportRangef(rng, "%s format %s has unknown verb %c", name, operation.Text, verb) return false } for _, flag := range operation.Flags { @@ -690,7 +704,7 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs continue } if !strings.ContainsRune(v.flags, rune(flag)) { - pass.ReportRangef(call, "%s format %s has unrecognized flag %c", name, operation.Text, flag) + pass.ReportRangef(rng, "%s format %s has unrecognized flag %c", name, operation.Text, flag) return false } } @@ -707,7 +721,7 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs // If len(argIndexes)>0, we have something like %.*s and all // indexes in argIndexes must be an integer. 
for _, argIndex := range argIndexes { - if !argCanBeChecked(pass, call, argIndex, firstArg, operation, name) { + if !argCanBeChecked(pass, call, rng, argIndex, firstArg, operation, name) { return } arg := call.Args[argIndex] @@ -716,7 +730,7 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs if reason != "" { details = " (" + reason + ")" } - pass.ReportRangef(call, "%s format %s uses non-int %s%s as argument of *", name, operation.Text, analysisinternal.Format(pass.Fset, arg), details) + pass.ReportRangef(rng, "%s format %s uses non-int %s%s as argument of *", name, operation.Text, analysisinternal.Format(pass.Fset, arg), details) return false } } @@ -738,12 +752,12 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs // Now check verb's type. verbArgIndex := operation.Verb.ArgIndex - if !argCanBeChecked(pass, call, verbArgIndex, firstArg, operation, name) { + if !argCanBeChecked(pass, call, rng, verbArgIndex, firstArg, operation, name) { return false } arg := call.Args[verbArgIndex] if isFunctionValue(pass, arg) && verb != 'p' && verb != 'T' { - pass.ReportRangef(call, "%s format %s arg %s is a func value, not called", name, operation.Text, analysisinternal.Format(pass.Fset, arg)) + pass.ReportRangef(rng, "%s format %s arg %s is a func value, not called", name, operation.Text, analysisinternal.Format(pass.Fset, arg)) return false } if reason, ok := matchArgType(pass, v.typ, arg); !ok { @@ -755,14 +769,14 @@ func okPrintfArg(pass *analysis.Pass, call *ast.CallExpr, maxArgIndex *int, firs if reason != "" { details = " (" + reason + ")" } - pass.ReportRangef(call, "%s format %s has arg %s of wrong type %s%s", name, operation.Text, analysisinternal.Format(pass.Fset, arg), typeString, details) + pass.ReportRangef(rng, "%s format %s has arg %s of wrong type %s%s", name, operation.Text, analysisinternal.Format(pass.Fset, arg), typeString, details) return false } // Detect recursive formatting via value's 
String/Error methods. // The '#' flag suppresses the methods, except with %x, %X, and %q. if v.typ&argString != 0 && v.verb != 'T' && (!strings.Contains(operation.Flags, "#") || strings.ContainsRune("qxX", v.verb)) { if methodName, ok := recursiveStringer(pass, arg); ok { - pass.ReportRangef(call, "%s format %s with arg %s causes recursive %s method call", name, operation.Text, analysisinternal.Format(pass.Fset, arg), methodName) + pass.ReportRangef(rng, "%s format %s with arg %s causes recursive %s method call", name, operation.Text, analysisinternal.Format(pass.Fset, arg), methodName) return false } } @@ -846,7 +860,7 @@ func isFunctionValue(pass *analysis.Pass, e ast.Expr) bool { // argCanBeChecked reports whether the specified argument is statically present; // it may be beyond the list of arguments or in a terminal slice... argument, which // means we can't see it. -func argCanBeChecked(pass *analysis.Pass, call *ast.CallExpr, argIndex, firstArg int, operation *fmtstr.Operation, name string) bool { +func argCanBeChecked(pass *analysis.Pass, call *ast.CallExpr, rng analysis.Range, argIndex, firstArg int, operation *fmtstr.Operation, name string) bool { if argIndex <= 0 { // Shouldn't happen, so catch it with prejudice. panic("negative argIndex") @@ -863,7 +877,7 @@ func argCanBeChecked(pass *analysis.Pass, call *ast.CallExpr, argIndex, firstArg // There are bad indexes in the format or there are fewer arguments than the format needs. // This is the argument number relative to the format: Printf("%s", "hi") will give 1 for the "hi". arg := argIndex - firstArg + 1 // People think of arguments as 1-indexed. 
- pass.ReportRangef(call, "%s format %s reads arg #%d, but call has %v", name, operation.Text, arg, count(len(call.Args)-firstArg, "arg")) + pass.ReportRangef(rng, "%s format %s reads arg #%d, but call has %v", name, operation.Text, arg, count(len(call.Args)-firstArg, "arg")) return false } diff --git a/go/analysis/passes/tests/tests.go b/go/analysis/passes/tests/tests.go index 9f59006ebb2..d4e9b025324 100644 --- a/go/analysis/passes/tests/tests.go +++ b/go/analysis/passes/tests/tests.go @@ -447,18 +447,6 @@ func checkExampleName(pass *analysis.Pass, fn *ast.FuncDecl) { } } -type tokenRange struct { - p, e token.Pos -} - -func (r tokenRange) Pos() token.Pos { - return r.p -} - -func (r tokenRange) End() token.Pos { - return r.e -} - func checkTest(pass *analysis.Pass, fn *ast.FuncDecl, prefix string) { // Want functions with 0 results and 1 parameter. if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || @@ -476,8 +464,9 @@ func checkTest(pass *analysis.Pass, fn *ast.FuncDecl, prefix string) { if tparams := fn.Type.TypeParams; tparams != nil && len(tparams.List) > 0 { // Note: cmd/go/internal/load also errors about TestXXX and BenchmarkXXX functions with type parameters. // We have currently decided to also warn before compilation/package loading. This can help users in IDEs. 
- at := tokenRange{tparams.Opening, tparams.Closing} - pass.ReportRangef(at, "%s has type parameters: it will not be run by go test as a %sXXX function", fn.Name.Name, prefix) + pass.ReportRangef(analysisinternal.Range(tparams.Opening, tparams.Closing), + "%s has type parameters: it will not be run by go test as a %sXXX function", + fn.Name.Name, prefix) } if !isTestSuffix(fn.Name.Name[len(prefix):]) { diff --git a/go/analysis/unitchecker/separate_test.go b/go/analysis/unitchecker/separate_test.go index 8f4a9193d3d..2198154364b 100644 --- a/go/analysis/unitchecker/separate_test.go +++ b/go/analysis/unitchecker/separate_test.go @@ -222,7 +222,7 @@ func MyPrintf(format string, args ...any) { // Observe that the example produces a fact-based diagnostic // from separate analysis of "main", "lib", and "fmt": - const want = `/main/main.go:6:2: [printf] separate/lib.MyPrintf format %s has arg 123 of wrong type int` + const want = `/main/main.go:6:16: [printf] separate/lib.MyPrintf format %s has arg 123 of wrong type int` sort.Strings(allDiagnostics) if got := strings.Join(allDiagnostics, "\n"); got != want { t.Errorf("Got: %s\nWant: %s", got, want) diff --git a/gopls/internal/test/integration/diagnostics/diagnostics_test.go b/gopls/internal/test/integration/diagnostics/diagnostics_test.go index 0d074333352..222077d2e55 100644 --- a/gopls/internal/test/integration/diagnostics/diagnostics_test.go +++ b/gopls/internal/test/integration/diagnostics/diagnostics_test.go @@ -456,7 +456,7 @@ func TestResolveDiagnosticWithDownload(t *testing.T) { // diagnostic for the wrong formatting type. 
env.AfterChange( Diagnostics( - env.AtRegexp("print.go", "fmt.Printf"), + env.AtRegexp("print.go", "%s"), WithMessage("wrong type int"), ), ) @@ -2071,7 +2071,7 @@ func MyPrintf(format string, args ...interface{}) { env.OpenFile("a/a.go") env.AfterChange( Diagnostics( - env.AtRegexp("a/a.go", "new.*Printf"), + env.AtRegexp("a/a.go", "%d"), WithMessage("format %d has arg \"s\" of wrong type string"), ), ) diff --git a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt index 2fa9fbeb2cc..252b4b4180a 100644 --- a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt +++ b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt @@ -34,7 +34,7 @@ func _() { // printf func _() { - printfWrapper("%s") //@diag(re`printfWrapper\(.*?\)`, re"example.com/bad.printfWrapper format %s reads arg #1, but call has 0 args") + printfWrapper("%s") //@diag(re`%s`, re"example.com/bad.printfWrapper format %s reads arg #1, but call has 0 args") } func printfWrapper(format string, args ...any) { diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index e8292d007d5..e46aab02d6b 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -670,3 +670,14 @@ func IsStdPackage(path string) bool { } return !strings.Contains(path[:slash], ".") && path != "testdata" } + +// Range returns an [analysis.Range] for the specified start and end positions. +func Range(pos, end token.Pos) analysis.Range { + return tokenRange{pos, end} +} + +// tokenRange is an implementation of the [analysis.Range] interface. 
+type tokenRange struct{ StartPos, EndPos token.Pos } + +func (r tokenRange) Pos() token.Pos { return r.StartPos } +func (r tokenRange) End() token.Pos { return r.EndPos } From 81de76b22eab23fdb2bb335a9bff16b599853f4c Mon Sep 17 00:00:00 2001 From: cuishuang Date: Fri, 2 May 2025 21:44:54 +0800 Subject: [PATCH 147/196] gopls/internal/analysis/modernize: fix bug in minmax analyzer that incorrectly handles nested if-else-if structures The minmax analyzer currently has a bug when handling nested `if-else-if` structures. When processing code patterns like ```go if index < 0 { x = 0 } else if index >= maxVal { x = maxVal } else { x = index } ``` it incorrectly transforms them into ```go if index < 0 { x = 0 } else x = min(index, maxVal) ``` Skip `else if` branches entirely rather than transforming them. Fixes golang/go#73576 Change-Id: I5ec1d5d34961fcb3343872bbc1f7edad5b956550 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669495 Reviewed-by: Xie Yuchen Auto-Submit: Alan Donovan Reviewed-by: Robert Findley Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/analysis/modernize/minmax.go | 12 ++++++++++++ .../modernize/testdata/src/minmax/minmax.go | 13 +++++++++++++ .../modernize/testdata/src/minmax/minmax.go.golden | 13 +++++++++++++ 3 files changed, 38 insertions(+) diff --git a/gopls/internal/analysis/modernize/minmax.go b/gopls/internal/analysis/modernize/minmax.go index 6c896289e1e..641ab38e889 100644 --- a/gopls/internal/analysis/modernize/minmax.go +++ b/gopls/internal/analysis/modernize/minmax.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/typeparams" @@ -174,6 +175,17 @@ func minmax(pass *analysis.Pass) { astFile := curFile.Node().(*ast.File) for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) { ifStmt :=
curIfStmt.Node().(*ast.IfStmt) + + // Don't bother handling "if a < b { lhs = rhs }" when it appears + // as the "else" branch of another if-statement. + // if cond { ... } else if a < b { lhs = rhs } + // (This case would require introducing another block + // if cond { ... } else { if a < b { lhs = rhs } } + // and checking that there is no following "else".) + if ek, _ := curIfStmt.ParentEdge(); ek == edge.IfStmt_Else { + continue + } + if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok && ifStmt.Init == nil && isInequality(compare.Op) != 0 && diff --git a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go index 5f404ed717d..74d84b2edf1 100644 --- a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go +++ b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go @@ -156,3 +156,16 @@ func underscoreAssign(a, b int) { _ = a } } + +// Regression test for https://github.com/golang/go/issues/73576. +func nopeIfElseIf(a int) int { + x := 0 + if a < 0 { + x = 0 + } else if a > 100 { + x = 100 + } else { + x = a + } + return x +} diff --git a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden index a13c72db5c0..6ae75ed5846 100644 --- a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden @@ -143,3 +143,16 @@ func underscoreAssign(a, b int) { _ = a } } + +// Regression test for https://github.com/golang/go/issues/73576. 
+func nopeIfElseIf(a int) int { + x := 0 + if a < 0 { + x = 0 + } else if a > 100 { + x = 100 + } else { + x = a + } + return x +} From 389a102e86144e488abb930a2e6ab6a1ec08871a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 29 May 2025 10:02:18 -0400 Subject: [PATCH 148/196] gopls/internal/telemetry/cmd/stacks: collect from go.dev frontend The previous storage endpoint is going away soon. Change-Id: If9001604d08ff7cdc1533f449892cf81bc6a17fb Reviewed-on: https://go-review.googlesource.com/c/tools/+/677155 Reviewed-by: Alessandro Arzilli LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- gopls/internal/telemetry/cmd/stacks/stacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gopls/internal/telemetry/cmd/stacks/stacks.go b/gopls/internal/telemetry/cmd/stacks/stacks.go index 17f180cacd0..5c7625e3b9c 100644 --- a/gopls/internal/telemetry/cmd/stacks/stacks.go +++ b/gopls/internal/telemetry/cmd/stacks/stacks.go @@ -346,7 +346,7 @@ func readReports(pcfg ProgramConfig, days int) (stacks map[string]map[Info]int64 for i := range days { date := t.Add(-time.Duration(i+1) * 24 * time.Hour).Format(time.DateOnly) - url := fmt.Sprintf("https://storage.googleapis.com/prod-telemetry-merged/%s.json", date) + url := fmt.Sprintf("https://telemetry.go.dev/data/%s", date) resp, err := http.Get(url) if err != nil { return nil, 0, nil, fmt.Errorf("error on GET %s: %v", url, err) From 2815c8bd92cb4ed1b4ef79260efaab44c6237172 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 23 May 2025 15:37:22 -0400 Subject: [PATCH 149/196] internal/tokeninternal: tag AddExistingFiles for go1.24 ...and use FileSet.AddExistingFiles in go1.25. This change is necessary so that CL 675736, which changes the representation of FileSet, can be merged. 
Updates golang/go#73205 Change-Id: Ic2815130c17b7cadd3d7b55076ad8482c508c3c7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675955 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- internal/tokeninternal/tokeninternal_go124.go | 8 +------- internal/tokeninternal/tokeninternal_go125.go | 7 +------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/internal/tokeninternal/tokeninternal_go124.go b/internal/tokeninternal/tokeninternal_go124.go index bcf346cac6f..2f8981b8886 100644 --- a/internal/tokeninternal/tokeninternal_go124.go +++ b/internal/tokeninternal/tokeninternal_go124.go @@ -2,10 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO(rfindley): once the new AddExistingFiles API has had some time to soak -// in std, use it in x/tools and change the condition below to !go1.25. - -//go:build !addexistingfiles +//go:build !go1.25 package tokeninternal @@ -22,9 +19,6 @@ import ( // AddExistingFiles adds the specified files to the FileSet if they // are not already present. It panics if any pair of files in the // resulting FileSet would overlap. -// -// TODO(adonovan): replace with FileSet.AddExistingFiles in go1.25, -// which is much more efficient. func AddExistingFiles(fset *token.FileSet, files []*token.File) { // This function cannot be implemented as: diff --git a/internal/tokeninternal/tokeninternal_go125.go b/internal/tokeninternal/tokeninternal_go125.go index 712c3414130..bbd5b9504a2 100644 --- a/internal/tokeninternal/tokeninternal_go125.go +++ b/internal/tokeninternal/tokeninternal_go125.go @@ -2,10 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO(rfindley): once the new AddExistingFiles API has had some time to soak -// in std, use it here behind the go1.25 build tag. 
- -//go:build addexistingfiles +//go:build go1.25 package tokeninternal @@ -14,8 +11,6 @@ import "go/token" // AddExistingFiles adds the specified files to the FileSet if they // are not already present. It panics if any pair of files in the // resulting FileSet would overlap. -// -// TODO(adonovan): eliminate when go1.25 is always available. func AddExistingFiles(fset *token.FileSet, files []*token.File) { fset.AddExistingFiles(files...) } From ad22223c4a507b44d14e658f84050b7f4c2d1504 Mon Sep 17 00:00:00 2001 From: Madeline Kalil Date: Tue, 20 May 2025 13:21:44 -0400 Subject: [PATCH 150/196] gopls/internal/golang: fix crash in definition of invalid "continue label" Telemetry reported an out-of-bounds panic for a definition request on the "continue" keyword in a "continue label" statement. A label with an invalid location would not be found on the path enclosing interval, causing this error. This is illustrated by the following example: ``` func F() { label: for i := range 10 { } for i := range 10 { continue label } } ``` Add a bounds check and return nil if the label is not found. Update the marker test for @def to accept variadic arguments. If no argument for dst is specified, assert that there is no definition returned - the value will be an empty protocol.Location. @def now also accepts a named argument for err. Add a new test wrapper GoToDefinitions that returns all result definitions. 
Fixes golang/go#73797 Change-Id: I266079251eb79646a7906b11286d8b20ceb966ac Reviewed-on: https://go-review.googlesource.com/c/tools/+/674497 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- .../testdata/definition/branch_issue73797.txt | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 gopls/internal/test/marker/testdata/definition/branch_issue73797.txt diff --git a/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt b/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt new file mode 100644 index 00000000000..5b100971903 --- /dev/null +++ b/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt @@ -0,0 +1,24 @@ +This test checks the case of a definition operation on a "continue" with an invalid label. +In gotip, the typechecker no longer associates the continue statement with its invalid label, +so this test case should only be run for go1.24 or earlier. +See the related change in go/types: https://go-review.git.corp.google.com/c/go/+/638257 + +-- flags -- +-max_go_command=go1.24 + +-- go.mod -- +module mod.com + +go 1.18 + +-- a/a.go -- +package a + +func InvalidLabel() { + label: + for i := 0; i < 10; i++ { + } + for i := 0; i < 10; i++ { + continue label //@def("continue") + } +} From bef2d59d69a4f0b0fd631b39ad14a8b4b89f688e Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 29 May 2025 12:59:01 -0400 Subject: [PATCH 151/196] internal/golang: re-fix crash in definition of invalid "continue label" The previous fix (CL 674497) added a test with max_go_command constraint, when what is needed is a max_go constraint (since it concerns the parser and type checker). Consequently the test didn't run on go1.25. Also, the meat of the fix appears to have been lost during a git merge. 
This CL adds support for _max_go, and for multi-result @def queries, including zero-result, since that's what the new test needs, and includes the fix for the original bounds check from the earlier patchsets of CL 674497. Fixes golang/go#73797 Change-Id: I9bbf089b2dd4669007d8dbf1c72e5a93de89d85b Reviewed-on: https://go-review.googlesource.com/c/tools/+/677295 Reviewed-by: Madeline Kalil LUCI-TryBot-Result: Go LUCI --- gopls/internal/golang/definition.go | 9 ++-- gopls/internal/test/marker/doc.go | 4 +- gopls/internal/test/marker/marker_test.go | 54 +++++++++++-------- .../testdata/definition/branch_issue73797.txt | 7 +-- .../definition/branch_issue73797_go124.txt | 25 +++++++++ 5 files changed, 69 insertions(+), 30 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/definition/branch_issue73797_go124.txt diff --git a/gopls/internal/golang/definition.go b/gopls/internal/golang/definition.go index d64a53a5114..27755d93653 100644 --- a/gopls/internal/golang/definition.go +++ b/gopls/internal/golang/definition.go @@ -87,11 +87,12 @@ func Definition(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p // Handle definition requests for various special kinds of syntax node. path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos) + ancestors := path[1:] switch node := path[0].(type) { // Handle the case where the cursor is on a return statement by jumping to the result variables. case *ast.ReturnStmt: var funcType *ast.FuncType - for _, n := range path[1:] { + for _, n := range ancestors { switch n := n.(type) { case *ast.FuncLit: funcType = n.Type @@ -132,9 +133,9 @@ func Definition(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p } case token.BREAK, token.CONTINUE: // Find innermost relevant ancestor for break/continue. 
- for i, n := range path[1:] { - if isLabeled { - l, ok := path[1:][i+1].(*ast.LabeledStmt) + for i, n := range ancestors { + if isLabeled && i+1 < len(ancestors) { + l, ok := ancestors[i+1].(*ast.LabeledStmt) if !(ok && l.Label.Name == label.Name()) { continue } diff --git a/gopls/internal/test/marker/doc.go b/gopls/internal/test/marker/doc.go index f3ad975d6fd..40f5b5fa70e 100644 --- a/gopls/internal/test/marker/doc.go +++ b/gopls/internal/test/marker/doc.go @@ -176,8 +176,8 @@ Here is the list of supported action markers: additional fields (source="compiler", kind="error"). Restore them using optional named arguments. - - def(src, dst location): performs a textDocument/definition request at - the src location, and check the result points to the dst location. + - def(src, want ...location): performs a textDocument/definition request at + the src location, and checks that the results equals want. - documentLink(golden): asserts that textDocument/documentLink returns links as described by the golden file. diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index e12fa0f46a3..a41c2f670cf 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -138,6 +138,15 @@ func Test(t *testing.T) { } testenv.NeedsGo1Point(t, go1point) } + if test.maxGoVersion != "" { + // A max Go version may be useful when (e.g.) a recent go/types + // fix makes it impossible to reproduce a certain older crash. + var go1point int + if _, err := fmt.Sscanf(test.maxGoVersion, "go1.%d", &go1point); err != nil { + t.Fatalf("parsing -max_go version: %v", err) + } + testenv.SkipAfterGo1Point(t, go1point) + } if test.minGoCommandVersion != "" { var go1point int if _, err := fmt.Sscanf(test.minGoCommandVersion, "go1.%d", &go1point); err != nil { @@ -627,18 +636,18 @@ type markerTest struct { flags []string // flags extracted from the special "flags" archive file. // Parsed flags values. 
See the flag definitions below for documentation. - minGoVersion string // minimum Go runtime version; max should never be needed - minGoCommandVersion string - maxGoCommandVersion string - cgo bool - writeGoSum []string - skipGOOS []string - skipGOARCH []string - ignoreExtraDiags bool - filterBuiltins bool - filterKeywords bool - errorsOK bool - mcp bool + minGoVersion, maxGoVersion string // min/max version of Go runtime + minGoCommandVersion, maxGoCommandVersion string // min/max version of ambient go command + + cgo bool + writeGoSum []string + skipGOOS []string + skipGOARCH []string + ignoreExtraDiags bool + filterBuiltins bool + filterKeywords bool + errorsOK bool + mcp bool } // flagSet returns the flagset used for parsing the special "flags" file in the @@ -646,6 +655,7 @@ type markerTest struct { func (t *markerTest) flagSet() *flag.FlagSet { flags := flag.NewFlagSet(t.name, flag.ContinueOnError) flags.StringVar(&t.minGoVersion, "min_go", "", "if set, the minimum go1.X version required for this test") + flags.StringVar(&t.maxGoVersion, "max_go", "", "if set, the maximum go1.X version required for this test") flags.StringVar(&t.minGoCommandVersion, "min_go_command", "", "if set, the minimum go1.X go command version required for this test") flags.StringVar(&t.maxGoCommandVersion, "max_go_command", "", "if set, the maximum go1.X go command version required for this test") flags.BoolVar(&t.cgo, "cgo", false, "if set, requires cgo (both the cgo tool and CGO_ENABLED=1)") @@ -1681,15 +1691,17 @@ func acceptCompletionMarker(mark marker, src protocol.Location, label string, go } // defMarker implements the @def marker, running textDocument/definition at -// the given src location and asserting that there is exactly one resulting -// location, matching dst. -// -// TODO(rfindley): support a variadic destination set. 
-func defMarker(mark marker, src, dst protocol.Location) { - got := mark.run.env.FirstDefinition(src) - if got != dst { - mark.errorf("definition location does not match:\n\tgot: %s\n\twant %s", - mark.run.fmtLoc(got), mark.run.fmtLoc(dst)) +// the given location and asserting that there the results match want. +func defMarker(mark marker, loc protocol.Location, want ...protocol.Location) { + env := mark.run.env + got, err := env.Editor.Definitions(env.Ctx, loc) + if err != nil { + mark.errorf("definition request failed: %v", err) + return + } + + if err := compareLocations(mark, got, want); err != nil { + mark.errorf("def failed: %v", err) } } diff --git a/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt b/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt index 5b100971903..1e8eb20c0e7 100644 --- a/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt +++ b/gopls/internal/test/marker/testdata/definition/branch_issue73797.txt @@ -4,7 +4,8 @@ so this test case should only be run for go1.24 or earlier. See the related change in go/types: https://go-review.git.corp.google.com/c/go/+/638257 -- flags -- --max_go_command=go1.24 +-min_go=go1.25 +-ignore_extra_diags -- go.mod -- module mod.com @@ -18,7 +19,7 @@ func InvalidLabel() { label: for i := 0; i < 10; i++ { } - for i := 0; i < 10; i++ { - continue label //@def("continue") + for i := 0; i < 10; i++ { //@loc(for, "for") + continue label //@def("continue", for) } } diff --git a/gopls/internal/test/marker/testdata/definition/branch_issue73797_go124.txt b/gopls/internal/test/marker/testdata/definition/branch_issue73797_go124.txt new file mode 100644 index 00000000000..5ab83771496 --- /dev/null +++ b/gopls/internal/test/marker/testdata/definition/branch_issue73797_go124.txt @@ -0,0 +1,25 @@ +This test checks the case of a definition operation on a "continue" with an invalid label. 
+In gotip, the typechecker no longer associates the continue statement with its invalid label, +so this test case should only be run for go1.24 or earlier. +See the related change in go/types: https://go-review.git.corp.google.com/c/go/+/638257 + +-- flags -- +-max_go=go1.24 +-ignore_extra_diags + +-- go.mod -- +module mod.com + +go 1.18 + +-- a/a.go -- +package a + +func InvalidLabel() { + label: + for i := 0; i < 10; i++ { + } + for i := 0; i < 10; i++ { + continue label //@def("continue") + } +} From 59198a184fb2c80c9930906bed0e897859abd98a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 29 May 2025 07:57:10 -0400 Subject: [PATCH 152/196] internal/mcp: add resource conformance test This uncovered a problem in how we represent resource results: contents is a slice, not a single value. I tried to use a file resource, but the server->client ListRoots call didn't work because synthetic responses aren't supported. Left a TODO. I enhanced the hello example so I that could use the MCP Inspector to work with resources. The falseSchema change fixes a potential bug where schemas are re-used, which is disallowed. 
Change-Id: I1f0198043cb6221c2a03f4f83aa0bafbc71167f4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677136 Reviewed-by: Robert Findley Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI Auto-Submit: Jonathan Amsterdam --- internal/mcp/conformance_test.go | 7 +- internal/mcp/examples/hello/main.go | 31 +++++++ internal/mcp/generate.go | 4 +- internal/mcp/mcp_test.go | 42 ++++++++-- internal/mcp/protocol.go | 4 +- internal/mcp/server.go | 14 ++-- .../conformance/server/resources.txtar | 81 +++++++++++++++++++ 7 files changed, 164 insertions(+), 19 deletions(-) create mode 100644 internal/mcp/testdata/conformance/server/resources.txtar diff --git a/internal/mcp/conformance_test.go b/internal/mcp/conformance_test.go index d7841d16456..dafba8b50ed 100644 --- a/internal/mcp/conformance_test.go +++ b/internal/mcp/conformance_test.go @@ -24,7 +24,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" - "golang.org/x/tools/internal/testenv" "golang.org/x/tools/txtar" ) @@ -56,10 +55,10 @@ type conformanceTest struct { server []jsonrpc2.Message // server messages } +// TODO(jba): support synthetic responses. // TODO(rfindley): add client conformance tests. func TestServerConformance(t *testing.T) { - testenv.NeedsGoExperiment(t, "synctest") var tests []*conformanceTest dir := filepath.Join("testdata", "conformance", "server") if err := filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error { @@ -245,7 +244,7 @@ func loadConformanceTest(dir, path string) (*conformanceTest, error) { // loadFeatures loads lists of named features from the archive file. 
loadFeatures := func(data []byte) []string { var feats []string - for line := range strings.SplitSeq(string(data), "\n") { + for line := range strings.Lines(string(data)) { if f := strings.TrimSpace(line); f != "" { feats = append(feats, f) } @@ -264,7 +263,7 @@ func loadConformanceTest(dir, path string) (*conformanceTest, error) { test.tools = loadFeatures(f.Data) case "prompts": test.prompts = loadFeatures(f.Data) - case "resource": + case "resources": test.resources = loadFeatures(f.Data) case "client": test.client, err = decodeMessages(f.Data) diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index c672c5f393a..1798c1810a7 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -9,6 +9,7 @@ import ( "flag" "fmt" "net/http" + "net/url" "os" "golang.org/x/tools/internal/mcp" @@ -43,6 +44,14 @@ func main() { mcp.Property("name", mcp.Description("the name to say hi to")), ))) server.AddPrompts(mcp.NewPrompt("greet", "", PromptHi)) + server.AddResources(&mcp.ServerResource{ + Resource: &mcp.Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + }, + Handler: handleEmbeddedResource, + }) if *httpAddr != "" { handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { @@ -56,3 +65,25 @@ func main() { } } } + +var embeddedResources = map[string]string{ + "info": "This is the hello example server.", +} + +func handleEmbeddedResource(_ context.Context, _ *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { + u, err := url.Parse(params.URI) + if err != nil { + return nil, err + } + if u.Scheme != "embedded" { + return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) + } + key := u.Opaque + text, ok := embeddedResources[key] + if !ok { + return nil, fmt.Errorf("no embedded resource named %q", key) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{mcp.NewTextResourceContents(params.URI, "text/plain", text)}, + }, nil +} 
diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 2d62d4350ca..670a79b5969 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -143,7 +143,7 @@ var declarations = config{ Fields: config{"Params": {Name: "ReadResourceParams"}}, }, "ReadResourceResult": { - Fields: config{"Contents": {Substitute: "*ResourceContents"}}, + Fields: config{"Contents": {Substitute: "[]*ResourceContents"}}, }, "Resource": {}, "ResourceListChangedNotification": { @@ -195,7 +195,7 @@ func main() { log.Fatal(err) } // Resolve the schema so we have the referents of all the Refs. - if _, err := schema.Resolve("", nil); err != nil { + if _, err := schema.Resolve(nil); err != nil { log.Fatal(err) } diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index ca62608d1ac..e908811941a 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log/slog" + "net/url" "path/filepath" "runtime" "slices" @@ -164,7 +165,7 @@ func TestEndToEnd(t *testing.T) { Description: "just fail", InputSchema: &jsonschema.Schema{ Type: "object", - AdditionalProperties: falseSchema, + AdditionalProperties: falseSchema(), }, }, { @@ -176,7 +177,7 @@ func TestEndToEnd(t *testing.T) { Properties: map[string]*jsonschema.Schema{ "Name": {Type: "string"}, }, - AdditionalProperties: falseSchema, + AdditionalProperties: falseSchema(), }, }, } @@ -245,11 +246,14 @@ func TestEndToEnd(t *testing.T) { } else { t.Errorf("reading %s: %v", tt.uri, err) } + } else if g, w := len(rres.Contents), 1; g != w { + t.Errorf("got %d contents, wanted %d", g, w) } else { - if got := rres.Contents.URI; got != tt.uri { + c := rres.Contents[0] + if got := c.URI; got != tt.uri { t.Errorf("got uri %q, want %q", got, tt.uri) } - if got := rres.Contents.MIMEType; got != tt.mimeType { + if got := c.MIMEType; got != tt.mimeType { t.Errorf("%s: got MIME type %q, want %q", tt.uri, got, tt.mimeType) } } @@ -413,13 +417,41 @@ var ( MIMEType: "text/plain", URI: 
"file:///fail.txt", } + resource3 = &Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + } readHandler = fileResourceHandler("testdata/files") resources = map[string]*ServerResource{ "info.txt": {resource1, readHandler}, "fail.txt": {resource2, readHandler}, + "info": {resource3, handleEmbeddedResource}, } ) +var embeddedResources = map[string]string{ + "info": "This is the MCP test server.", +} + +func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { + u, err := url.Parse(params.URI) + if err != nil { + return nil, err + } + if u.Scheme != "embedded" { + return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) + } + key := u.Opaque + text, ok := embeddedResources[key] + if !ok { + return nil, fmt.Errorf("no embedded resource named %q", key) + } + return &ReadResourceResult{ + Contents: []*ResourceContents{NewTextResourceContents(params.URI, "text/plain", text)}, + }, nil +} + // Add calls the given function to add the named features. func add[T any](m map[string]T, add func(...T), names ...string) { for _, name := range names { @@ -643,4 +675,4 @@ func traceCalls[S ClientSession | ServerSession](w io.Writer, prefix string) Mid } } -var falseSchema = &jsonschema.Schema{Not: &jsonschema.Schema{}} +func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 15babb6fd5f..4e212c6afea 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -430,8 +430,8 @@ func (x *ReadResourceParams) GetMeta() *Meta { return &x.Meta } type ReadResourceResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. 
- Meta Meta `json:"_meta,omitempty"` - Contents *ResourceContents `json:"contents"` + Meta Meta `json:"_meta,omitempty"` + Contents []*ResourceContents `json:"contents"` } func (x *ReadResourceResult) GetMeta() *Meta { return &x.Meta } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 76bebb25a0a..56604d4d071 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -277,11 +277,13 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) } // As a convenience, populate some fields. - if res.Contents.URI == "" { - res.Contents.URI = uri - } - if res.Contents.MIMEType == "" { - res.Contents.MIMEType = resource.Resource.MIMEType + for _, c := range res.Contents { + if c.URI == "" { + c.URI = uri + } + if c.MIMEType == "" { + c.MIMEType = resource.Resource.MIMEType + } } return res, nil } @@ -329,7 +331,7 @@ func fileResourceHandler(dir string) ResourceHandler { return nil, err } // TODO(jba): figure out mime type. - return &ReadResourceResult{Contents: NewBlobResourceContents(params.URI, "text/plain", data)}, nil + return &ReadResourceResult{Contents: []*ResourceContents{NewBlobResourceContents(params.URI, "text/plain", data)}}, nil } } diff --git a/internal/mcp/testdata/conformance/server/resources.txtar b/internal/mcp/testdata/conformance/server/resources.txtar new file mode 100644 index 00000000000..6f3679d9396 --- /dev/null +++ b/internal/mcp/testdata/conformance/server/resources.txtar @@ -0,0 +1,81 @@ +Check behavior of a server with just resources. + +Fixed bugs: +- A resource result holds a slice of contents, not just one. 
+ +-- resources -- +info + +-- client -- +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc": "2.0", "id": 2, "method": "resources/list" } +{ + "jsonrpc": "2.0", "id": 3, + "method": "resources/read", + "params": { + "uri": "embedded:info" + } +} +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "_meta": {}, + "capabilities": { + "completions": {}, + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "_meta": {}, + "resources": [ + { + "mimeType": "text/plain", + "name": "info", + "uri": "embedded:info" + } + ] + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "_meta": {}, + "contents": [ + { + "uri": "embedded:info", + "mimeType": "text/plain", + "text": "This is the MCP test server." + } + ] + } +} From 147cb9c3c6288958e1f06bc732509a3a7413b268 Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Mon, 28 Apr 2025 20:57:16 -0600 Subject: [PATCH 153/196] gopls/internal/analysis/modernize: minmax: put comments at proper positions This CL improves the approach to preserve comments in modernize minmax. It preserves the comments for each if/else block and put them above its corresponding variables after minmax modernizing. 
Updates golang/go#73473 Change-Id: Ia91dec8d03a0009aa688b5eb1370c9fd43a402d3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/668495 Reviewed-by: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- gopls/internal/analysis/modernize/minmax.go | 25 ++++-- .../testdata/src/minmax/minmax.go.golden | 76 +++++++++++-------- 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/gopls/internal/analysis/modernize/minmax.go b/gopls/internal/analysis/modernize/minmax.go index 641ab38e889..4ab869ae903 100644 --- a/gopls/internal/analysis/modernize/minmax.go +++ b/gopls/internal/analysis/modernize/minmax.go @@ -48,6 +48,15 @@ func minmax(pass *analysis.Pass) { rhs = tassign.Rhs[0] scope = pass.TypesInfo.Scopes[ifStmt.Body] sign = isInequality(compare.Op) + + // callArg formats a call argument, preserving comments from [start-end). + callArg = func(arg ast.Expr, start, end token.Pos) string { + comments := allComments(file, start, end) + return cond(arg == b, ", ", "") + // second argument needs a comma + cond(comments != "", "\n", "") + // comments need their own line + comments + + analysisinternal.Format(pass.Fset, arg) + } ) if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) { @@ -91,12 +100,12 @@ func minmax(pass *analysis.Pass) { // Replace IfStmt with lhs = min(a, b). Pos: ifStmt.Pos(), End: ifStmt.End(), - NewText: fmt.Appendf(nil, "%s%s = %s(%s, %s)", - allComments(file, ifStmt.Pos(), ifStmt.End()), + NewText: fmt.Appendf(nil, "%s = %s(%s%s)", analysisinternal.Format(pass.Fset, lhs), sym, - analysisinternal.Format(pass.Fset, a), - analysisinternal.Format(pass.Fset, b)), + callArg(a, ifStmt.Pos(), ifStmt.Else.Pos()), + callArg(b, ifStmt.Else.Pos(), ifStmt.End()), + ), }}, }}, }) @@ -154,13 +163,13 @@ func minmax(pass *analysis.Pass) { Pos: fassign.Pos(), End: ifStmt.End(), // Replace "x := a; if ... {}" with "x = min(...)", preserving comments. 
- NewText: fmt.Appendf(nil, "%s %s %s %s(%s, %s)", - allComments(file, fassign.Pos(), ifStmt.End()), + NewText: fmt.Appendf(nil, "%s %s %s(%s%s)", analysisinternal.Format(pass.Fset, lhs), fassign.Tok.String(), sym, - analysisinternal.Format(pass.Fset, a), - analysisinternal.Format(pass.Fset, b)), + callArg(a, fassign.Pos(), ifStmt.Pos()), + callArg(b, ifStmt.Pos(), ifStmt.End()), + ), }}, }}, }) diff --git a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden index 6ae75ed5846..f8dc94b3702 100644 --- a/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden @@ -1,45 +1,52 @@ package minmax func ifmin(a, b int) { - // A - // B - // want "if statement can be modernized using max" - // C - // D - // E - x := max(a, b) + x := max( + // A + // B + a, + // want "if statement can be modernized using max" + // C + // D + // E + b) print(x) } func ifmax(a, b int) { - // want "if statement can be modernized using min" - x := min(a, b) + x := min(a, + // want "if statement can be modernized using min" + b) print(x) } func ifminvariant(a, b int) { - // want "if statement can be modernized using min" - x := min(a, b) + x := min(a, + // want "if statement can be modernized using min" + b) print(x) } func ifmaxvariant(a, b int) { - // want "if statement can be modernized using min" - x := min(a, b) + x := min(a, + // want "if statement can be modernized using min" + b) print(x) } func ifelsemin(a, b int) { var x int // A // B - // want "if/else statement can be modernized using min" - // C - // D - // E - // F - // G - // H - x = min(a, b) + x = min( + // want "if/else statement can be modernized using min" + // C + // D + // E + a, + // F + // G + // H + b) print(x) } @@ -47,12 +54,14 @@ func ifelsemax(a, b int) { // A var x int // B // C - // want "if/else statement can be modernized using max" - 
// D - // E - // F - // G - x = max(a, b) + x = max( + // want "if/else statement can be modernized using max" + // D + // E + // F + a, + // G + b) print(x) } @@ -79,8 +88,9 @@ func nopeIfStmtHasInitStmt() { // Regression test for a bug: fix was "y := max(x, y)". func oops() { x := 1 - // want "if statement can be modernized using max" - y := max(x, 2) + y := max(x, + // want "if statement can be modernized using max" + 2) print(y) } @@ -119,9 +129,11 @@ func nopeHasElseBlock(x int) int { } func fix72727(a, b int) { - // some important comment. DO NOT REMOVE. - // want "if statement can be modernized using max" - o := max(a-42, b) + o := max( + // some important comment. DO NOT REMOVE. + a-42, + // want "if statement can be modernized using max" + b) } type myfloat float64 From d3809eaa245b7132b99349fea2564977d8ea16b5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 29 May 2025 13:48:50 -0400 Subject: [PATCH 154/196] internal/mcp: handle synthetic responses in conformance test Modify the conformance test driver to support outgoing responses. They are returned for incoming requests in order. Add a test that returns a file resource, which involves the server sending a roots/list request to the client. 
Change-Id: Ie7eb69a8c99ccbf11721518cfffb21244c03e532 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677475 Commit-Queue: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/jsonrpc2_v2/messages.go | 2 +- internal/mcp/conformance_test.go | 90 ++++++++++++++----- .../conformance/server/resources.txtar | 39 ++++++++ 3 files changed, 108 insertions(+), 23 deletions(-) diff --git a/internal/jsonrpc2_v2/messages.go b/internal/jsonrpc2_v2/messages.go index 0aa321d92b6..639c735ed0a 100644 --- a/internal/jsonrpc2_v2/messages.go +++ b/internal/jsonrpc2_v2/messages.go @@ -165,7 +165,7 @@ func DecodeMessage(data []byte) (Message, error) { return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err) } if msg.VersionTag != wireVersion { - return nil, fmt.Errorf("invalid message version tag %s expected %s", msg.VersionTag, wireVersion) + return nil, fmt.Errorf("invalid message version tag %q; expected %q", msg.VersionTag, wireVersion) } id, err := MakeID(msg.ID) if err != nil { diff --git a/internal/mcp/conformance_test.go b/internal/mcp/conformance_test.go index dafba8b50ed..0510b0db072 100644 --- a/internal/mcp/conformance_test.go +++ b/internal/mcp/conformance_test.go @@ -17,7 +17,6 @@ import ( "os" "path/filepath" "strings" - "sync" "testing" "testing/synctest" @@ -55,7 +54,6 @@ type conformanceTest struct { server []jsonrpc2.Message // server messages } -// TODO(jba): support synthetic responses. // TODO(rfindley): add client conformance tests. func TestServerConformance(t *testing.T) { @@ -119,48 +117,96 @@ func runServerTest(t *testing.T, test *conformanceTest) { t.Fatal(err) } - // Collect server messages asynchronously. 
- var wg sync.WaitGroup + writeMsg := func(msg jsonrpc2.Message) { + if _, err := cStream.Write(ctx, msg); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + var ( serverMessages []jsonrpc2.Message - serverErr error // abnormal failure of the server stream + outRequests []*jsonrpc2.Request + outResponses []*jsonrpc2.Response ) - wg.Add(1) - go func() { - defer wg.Done() + + // Separate client requests and responses; we use them differently. + for _, msg := range test.client { + switch msg := msg.(type) { + case *jsonrpc2.Request: + outRequests = append(outRequests, msg) + case *jsonrpc2.Response: + outResponses = append(outResponses, msg) + default: + t.Fatalf("bad message type %T", msg) + } + } + + // nextResponse handles incoming requests and notifications, and returns the + // next incoming response. + nextResponse := func() (*jsonrpc2.Response, error, bool) { for { msg, _, err := cStream.Read(ctx) if err != nil { // TODO(rfindley): we don't document (or want to document) that the in // memory transports use a net.Pipe. How can users detect this failure? // Should we promote it to EOF? - if !errors.Is(err, io.ErrClosedPipe) { - serverErr = err + if errors.Is(err, io.ErrClosedPipe) { + err = nil } - break + return nil, err, false } serverMessages = append(serverMessages, msg) + if req, ok := msg.(*jsonrpc2.Request); ok && req.ID.IsValid() { + // Pair up the next outgoing response with this request. + // We assume requests arrive in the same order every time. + if len(outResponses) == 0 { + t.Fatalf("no outgoing response for request %v", req) + } + outResponses[0].ID = req.ID + writeMsg(outResponses[0]) + outResponses = outResponses[1:] + continue + } + return msg.(*jsonrpc2.Response), nil, true } - }() + } - // Write client messages to the stream. - for _, msg := range test.client { - if _, err := cStream.Write(ctx, msg); err != nil { - t.Fatalf("Write failed: %v", err) + // Synthetic peer interacts with real peer. 
+ for _, req := range outRequests { + writeMsg(req) + if req.ID.IsValid() { + // A request (as opposed to a notification). Wait for the response. + res, err, ok := nextResponse() + if err != nil { + t.Fatalf("reading server messages failed: %v", err) + } + if !ok { + t.Fatalf("missing response for request %v", req) + } + if res.ID != req.ID { + t.Fatalf("out-of-order response %v to request %v", req, res) + } } } - + // There might be more notifications or requests, but there shouldn't be more + // responses. + // Run this in a goroutine so the current thread can wait for it. + var extra *jsonrpc2.Response + go func() { + extra, err, _ = nextResponse() + }() // Before closing the stream, wait for all messages to be processed. synctest.Wait() + if err != nil { + t.Fatalf("reading server messages failedd: %v", err) + } + if extra != nil { + t.Fatalf("got extra response: %v", extra) + } if err := cStream.Close(); err != nil { t.Fatalf("Stream.Close failed: %v", err) } - ss.Wait() - wg.Wait() - if serverErr != nil { - t.Fatalf("reading server messages failed: %v", serverErr) - } // Handle server output. If -update is set, write the 'server' file. // Otherwise, compare with expected. 
diff --git a/internal/mcp/testdata/conformance/server/resources.txtar b/internal/mcp/testdata/conformance/server/resources.txtar index 6f3679d9396..6f1f3533ce6 100644 --- a/internal/mcp/testdata/conformance/server/resources.txtar +++ b/internal/mcp/testdata/conformance/server/resources.txtar @@ -5,6 +5,7 @@ Fixed bugs: -- resources -- info +info.txt -- client -- { @@ -25,6 +26,19 @@ info "uri": "embedded:info" } } +{ + "jsonrpc": "2.0", "id": 3, + "method": "resources/read", + "params": { + "uri": "file:///info.txt" + } +} +{ + "jsonrpc": "2.0", "id": 0, + "result": { + "roots": [] + } +} -- server -- { "jsonrpc": "2.0", @@ -61,6 +75,11 @@ info "mimeType": "text/plain", "name": "info", "uri": "embedded:info" + }, + { + "mimeType": "text/plain", + "name": "public", + "uri": "file:///info.txt" } ] } @@ -79,3 +98,23 @@ info ] } } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "roots/list", + "params": null +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "_meta": {}, + "contents": [ + { + "uri": "file:///info.txt", + "mimeType": "text/plain", + "blob": "Q29udGVudHMK" + } + ] + } +} From c8e47eb2c1e78b0b03c3d106761301c1b0c73c82 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 11:17:38 -0400 Subject: [PATCH 155/196] internal/gofix: document batch fix commands Updates golang/go#32816 Change-Id: Ifd482757d2ca57d4556e27b732cf2fba59d1c729 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677515 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- gopls/doc/analyzers.md | 2 +- gopls/internal/doc/api.json | 4 ++-- internal/gofix/cmd/gofix/main.go | 18 +++++++++++++++--- internal/gofix/doc.go | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index 999e34d9a98..06892852319 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3573,7 +3573,7 @@ The proposal https://go.dev/issue/32816 introduces the "//go:fix" directives. 
You can use this (officially unsupported) command to apply gofix fixes en masse: - $ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./... + $ go run golang.org/x/tools/internal/gofix/cmd/gofix@latest -test ./... (Do not use "go get -tool" to add gopls as a dependency of your module; gopls commands must be built from their release branch.) diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 1bab1dc08f2..cd325f364a3 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1450,7 +1450,7 @@ }, { "Name": "\"gofix\"", - "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete package,\nas when the import path has changed or a higher major version is available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. 
When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", + "Doc": "apply fixes based on go:fix comment 
directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete package,\nas when the import path has changed or a higher major version is available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). 
However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/internal/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", "Default": "true", "Status": "" }, @@ -3176,7 +3176,7 @@ }, { "Name": "gofix", - "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete 
package,\nas when the import path has changed or a higher major version is available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). 
However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", + "Doc": "apply fixes based on go:fix comment directives\n\nThe gofix analyzer inlines functions and constants that are marked for inlining.\n\n## Functions\n\nGiven a function that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nthis analyzer will recommend that calls to the function elsewhere, in the same\nor other packages, should be inlined.\n\nInlining can be used to move off of a deprecated function:\n\n\t// Deprecated: prefer Pow(x, 2).\n\t//go:fix inline\n\tfunc Square(x int) int { return Pow(x, 2) }\n\nIt can also be used to move off of an obsolete package,\nas when the import path has changed or a higher major version is 
available:\n\n\tpackage pkg\n\n\timport pkg2 \"pkg/v2\"\n\n\t//go:fix inline\n\tfunc F() { pkg2.F(nil) }\n\nReplacing a call pkg.F() by pkg2.F(nil) can have no effect on the program,\nso this mechanism provides a low-risk way to update large numbers of calls.\nWe recommend, where possible, expressing the old API in terms of the new one\nto enable automatic migration.\n\nThe inliner takes care to avoid behavior changes, even subtle ones,\nsuch as changes to the order in which argument expressions are\nevaluated. When it cannot safely eliminate all parameter variables,\nit may introduce a \"binding declaration\" of the form\n\n\tvar params = args\n\nto evaluate argument expressions in the correct order and bind them to\nparameter variables. Since the resulting code transformation may be\nstylistically suboptimal, such inlinings may be disabled by specifying\nthe -gofix.allow_binding_decl=false flag to the analyzer driver.\n\n(In cases where it is not safe to \"reduce\" a call—that is, to replace\na call f(x) by the body of function f, suitably substituted—the\ninliner machinery is capable of replacing f by a function literal,\nfunc(){...}(). 
However, the gofix analyzer discards all such\n\"literalizations\" unconditionally, again on grounds of style.)\n\n## Constants\n\nGiven a constant that is marked for inlining, like this one:\n\n\t//go:fix inline\n\tconst Ptr = Pointer\n\nthis analyzer will recommend that uses of Ptr should be replaced with Pointer.\n\nAs with functions, inlining can be used to replace deprecated constants and\nconstants in obsolete packages.\n\nA constant definition can be marked for inlining only if it refers to another\nnamed constant.\n\nThe \"//go:fix inline\" comment must appear before a single const declaration on its own,\nas above; before a const declaration that is part of a group, as in this case:\n\n\tconst (\n\t C = 1\n\t //go:fix inline\n\t Ptr = Pointer\n\t)\n\nor before a group, applying to every constant in the group:\n\n\t//go:fix inline\n\tconst (\n\t\tPtr = Pointer\n\t Val = Value\n\t)\n\nThe proposal https://go.dev/issue/32816 introduces the \"//go:fix\" directives.\n\nYou can use this (officially unsupported) command to apply gofix fixes en masse:\n\n\t$ go run golang.org/x/tools/internal/gofix/cmd/gofix@latest -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)", "URL": "https://pkg.go.dev/golang.org/x/tools/internal/gofix", "Default": true }, diff --git a/internal/gofix/cmd/gofix/main.go b/internal/gofix/cmd/gofix/main.go index 9ec77943774..14d120838f3 100644 --- a/internal/gofix/cmd/gofix/main.go +++ b/internal/gofix/cmd/gofix/main.go @@ -2,10 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The inline command applies the inliner to the specified packages of -// Go source code. Run with: +// The gofix command applies the inliner to the specified packages of +// Go source code. Run this command to report all fixes: // -// $ go run ./internal/analysis/gofix/main.go -fix packages... 
+// $ go run ./internal/gofix/cmd/gofix packages... +// +// Run this command to preview the changes: +// +// $ go run ./internal/gofix/cmd/gofix -fix -diff packages... +// +// And run this command to apply them, including ones in test files: +// +// $ go run ./internal/gofix/cmd/gofix -fix -test packages... +// +// This internal command is not officially supported. In the long +// term, we plan to migrate this functionality into "go fix"; see Go +// issues https//go.dev/issue/32816, 71859, 73605. package main import ( diff --git a/internal/gofix/doc.go b/internal/gofix/doc.go index b3e693d7dce..0c883370705 100644 --- a/internal/gofix/doc.go +++ b/internal/gofix/doc.go @@ -98,7 +98,7 @@ The proposal https://go.dev/issue/32816 introduces the "//go:fix" directives. You can use this (officially unsupported) command to apply gofix fixes en masse: - $ go run golang.org/x/tools/gopls/internal/analysis/gofix/cmd/gofix@latest -test ./... + $ go run golang.org/x/tools/internal/gofix/cmd/gofix@latest -test ./... (Do not use "go get -tool" to add gopls as a dependency of your module; gopls commands must be built from their release branch.) 
From cb264bf6180511862f7c4dedc7f51284c1e1af88 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 30 May 2025 17:42:51 +0000 Subject: [PATCH 156/196] internal/mcp: clean up diagnostics from Google import Change-Id: Icf388c48761f0ed504f45cd07e6d1662693b1e2e Reviewed-on: https://go-review.googlesource.com/c/tools/+/677517 Reviewed-by: Alan Donovan Auto-Submit: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 46 ++++++++++++++--------------- internal/mcp/jsonschema/validate.go | 2 +- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 1d8252c0e2f..3b577a19896 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -129,14 +129,14 @@ type ClientSession struct { // Close performs a graceful close of the connection, preventing new requests // from being handled, and waiting for ongoing requests to return. Close then // terminates the connection. -func (c *ClientSession) Close() error { - return c.conn.Close() +func (cs *ClientSession) Close() error { + return cs.conn.Close() } // Wait waits for the connection to be closed by the server. // Generally, clients should be responsible for closing the connection. -func (c *ClientSession) Wait() error { - return c.conn.Wait() +func (cs *ClientSession) Wait() error { + return cs.conn.Wait() } // AddRoots adds the given roots to the client, @@ -232,33 +232,33 @@ func (cs *ClientSession) methodHandler() MethodHandler[ClientSession] { // getConn implements [session.getConn]. func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } -func (c *ClientSession) ping(ct context.Context, params *PingParams) (Result, error) { +func (cs *ClientSession) ping(ct context.Context, params *PingParams) (Result, error) { return emptyResult{}, nil } // Ping makes an MCP "ping" request to the server. 
-func (c *ClientSession) Ping(ctx context.Context, params *PingParams) error { - return call(ctx, c.conn, methodPing, params, nil) +func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { + return call(ctx, cs.conn, methodPing, params, nil) } // ListPrompts lists prompts that are currently available on the server. -func (c *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { - return standardCall[ListPromptsResult](ctx, c.conn, methodListPrompts, params) +func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { + return standardCall[ListPromptsResult](ctx, cs.conn, methodListPrompts, params) } // GetPrompt gets a prompt from the server. -func (c *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { - return standardCall[GetPromptResult](ctx, c.conn, methodGetPrompt, params) +func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { + return standardCall[GetPromptResult](ctx, cs.conn, methodGetPrompt, params) } // ListTools lists tools that are currently available on the server. -func (c *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return standardCall[ListToolsResult](ctx, c.conn, methodListTools, params) +func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { + return standardCall[ListToolsResult](ctx, cs.conn, methodListTools, params) } // CallTool calls the tool with the given name and arguments. // Pass a [CallToolOptions] to provide additional request fields. 
-func (c *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { +func (cs *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { defer func() { if err != nil { err = fmt.Errorf("calling tool %q: %w", name, err) @@ -273,11 +273,11 @@ func (c *ClientSession) CallTool(ctx context.Context, name string, args map[stri Name: name, Arguments: json.RawMessage(data), } - return standardCall[CallToolResult](ctx, c.conn, methodCallTool, params) + return standardCall[CallToolResult](ctx, cs.conn, methodCallTool, params) } -func (c *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { - return call(ctx, c.conn, methodSetLevel, params, nil) +func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { + return call(ctx, cs.conn, methodSetLevel, params, nil) } // NOTE: the following struct should consist of all fields of callToolParams except name and arguments. @@ -288,13 +288,13 @@ type CallToolOptions struct { } // ListResources lists the resources that are currently available on the server. -func (c *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { - return standardCall[ListResourcesResult](ctx, c.conn, methodListResources, params) +func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { + return standardCall[ListResourcesResult](ctx, cs.conn, methodListResources, params) } // ReadResource ask the server to read a resource and return its contents. 
-func (c *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { - return standardCall[ReadResourceResult](ctx, c.conn, methodReadResource, params) +func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { + return standardCall[ReadResourceResult](ctx, cs.conn, methodReadResource, params) } func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { @@ -319,14 +319,14 @@ func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, para // Tools provides an iterator for all tools available on the server, // automatically fetching pages and managing cursors. // The `params` argument can set the initial cursor. -func (c *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[Tool, error] { +func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[Tool, error] { currentParams := &ListToolsParams{} if params != nil { *currentParams = *params } return func(yield func(Tool, error) bool) { for { - res, err := c.ListTools(ctx, currentParams) + res, err := cs.ListTools(ctx, currentParams) if err != nil { yield(Tool{}, err) return diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index f0ddcf6ce58..466e34506aa 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -50,7 +50,7 @@ func (rs *Resolved) validateDefaults() error { if s.Default != nil { var d any if err := json.Unmarshal(s.Default, &d); err != nil { - fmt.Errorf("unmarshaling default value of schema %s: %w", s, err) + return fmt.Errorf("unmarshaling default value of schema %s: %w", s, err) } if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { return err From 218e5f2e9d0c005d461c8075349966ad0b3c7b57 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 13:31:25 -0400 
Subject: [PATCH 157/196] gopls/internal: handle errors, or document that we ignore them This CL documents every case in which a call to a function discards its error result, except for infallible or benign cases like bytes.Buffer.WriteByte or fmt.Println. In some cases, the function were reworked to avoid returning errors. Updates golang/go#73930 Change-Id: If6871ec5dfd6560f98707149ba9e29cd1efbe796 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677516 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Commit-Queue: Alan Donovan Auto-Submit: Alan Donovan --- .../simplifycompositelit.go | 4 +- .../analysis/simplifyslice/simplifyslice.go | 2 +- gopls/internal/cache/future_test.go | 2 +- gopls/internal/cache/port_test.go | 2 +- gopls/internal/cache/typerefs/pkgrefs_test.go | 4 +- gopls/internal/cmd/cmd.go | 2 +- gopls/internal/cmd/help_test.go | 4 +- gopls/internal/cmd/info.go | 9 ++-- gopls/internal/cmd/stats.go | 4 +- gopls/internal/debug/info.go | 53 +++++++++---------- gopls/internal/debug/info_test.go | 9 +--- gopls/internal/debug/serve.go | 2 +- gopls/internal/doc/generate/generate.go | 8 +-- .../internal/golang/completion/completion.go | 9 ++-- .../internal/golang/completion/unimported.go | 17 +++--- gopls/internal/golang/hover.go | 4 +- gopls/internal/licenses/licenses_test.go | 2 +- gopls/internal/lsprpc/export_test.go | 2 +- gopls/internal/lsprpc/lsprpc.go | 4 +- gopls/internal/lsprpc/middleware_test.go | 4 +- gopls/internal/mcp/mcp.go | 2 +- gopls/internal/protocol/generate/main.go | 4 +- gopls/internal/server/general.go | 6 +-- gopls/internal/server/workspace.go | 2 +- gopls/internal/telemetry/telemetry_test.go | 2 +- .../test/integration/bench/bench_test.go | 4 +- .../diagnostics/diagnostics_test.go | 4 +- .../internal/test/integration/fake/editor.go | 3 +- .../internal/test/integration/fake/sandbox.go | 2 +- .../internal/test/integration/fake/workdir.go | 2 +- .../test/integration/misc/imports_test.go | 8 +-- 
.../test/integration/misc/misc_test.go | 2 +- .../test/integration/misc/test_test.go | 2 +- .../test/integration/modfile/modfile_test.go | 4 +- gopls/internal/test/integration/regtest.go | 6 +-- gopls/internal/test/integration/runner.go | 4 +- gopls/internal/test/marker/marker_test.go | 2 +- gopls/internal/vulncheck/vulntest/db.go | 2 +- 38 files changed, 103 insertions(+), 105 deletions(-) diff --git a/gopls/internal/analysis/simplifycompositelit/simplifycompositelit.go b/gopls/internal/analysis/simplifycompositelit/simplifycompositelit.go index b38ccf4d5ed..3e54dc27b0d 100644 --- a/gopls/internal/analysis/simplifycompositelit/simplifycompositelit.go +++ b/gopls/internal/analysis/simplifycompositelit/simplifycompositelit.go @@ -89,7 +89,7 @@ func simplifyLiteral(pass *analysis.Pass, typ reflect.Value, astType, x ast.Expr // literal type may be omitted if inner, ok := x.(*ast.CompositeLit); ok && match(typ, reflect.ValueOf(inner.Type)) { var b bytes.Buffer - printer.Fprint(&b, pass.Fset, inner.Type) + printer.Fprint(&b, pass.Fset, inner.Type) // ignore error createDiagnostic(pass, inner.Type.Pos(), inner.Type.End(), b.String()) } // if the outer literal's element type is a pointer type *T @@ -100,7 +100,7 @@ func simplifyLiteral(pass *analysis.Pass, typ reflect.Value, astType, x ast.Expr if inner, ok := addr.X.(*ast.CompositeLit); ok { if match(reflect.ValueOf(ptr.X), reflect.ValueOf(inner.Type)) { var b bytes.Buffer - printer.Fprint(&b, pass.Fset, inner.Type) + printer.Fprint(&b, pass.Fset, inner.Type) // ignore error // Account for the & by subtracting 1 from typ.Pos(). 
createDiagnostic(pass, inner.Type.Pos()-1, inner.Type.End(), "&"+b.String()) } diff --git a/gopls/internal/analysis/simplifyslice/simplifyslice.go b/gopls/internal/analysis/simplifyslice/simplifyslice.go index 28cc266d713..8aae3c67029 100644 --- a/gopls/internal/analysis/simplifyslice/simplifyslice.go +++ b/gopls/internal/analysis/simplifyslice/simplifyslice.go @@ -82,7 +82,7 @@ func run(pass *analysis.Pass) (any, error) { return } var b bytes.Buffer - printer.Fprint(&b, pass.Fset, expr.High) + printer.Fprint(&b, pass.Fset, expr.High) // ignore error pass.Report(analysis.Diagnostic{ Pos: expr.High.Pos(), End: expr.High.End(), diff --git a/gopls/internal/cache/future_test.go b/gopls/internal/cache/future_test.go index d96dc0f5317..033ecea5259 100644 --- a/gopls/internal/cache/future_test.go +++ b/gopls/internal/cache/future_test.go @@ -145,7 +145,7 @@ func TestFutureCache_Retrying(t *testing.T) { defer cancels[9]() dones[9] <- struct{}{} - g.Wait() + _ = g.Wait() // can't fail t.Logf("started %d computations", started.Load()) if got := lastValue.Load(); got != 9 { diff --git a/gopls/internal/cache/port_test.go b/gopls/internal/cache/port_test.go index 5d0c5d4a50f..e1789b89c37 100644 --- a/gopls/internal/cache/port_test.go +++ b/gopls/internal/cache/port_test.go @@ -65,7 +65,7 @@ func TestMatchingPortsStdlib(t *testing.T) { }) } }) - g.Wait() + _ = g.Wait() // can't fail } func matchingPreferredPorts(tb testing.TB, fh file.Handle, trimContent bool) map[port]unit { diff --git a/gopls/internal/cache/typerefs/pkgrefs_test.go b/gopls/internal/cache/typerefs/pkgrefs_test.go index ce297e4380b..0500120c977 100644 --- a/gopls/internal/cache/typerefs/pkgrefs_test.go +++ b/gopls/internal/cache/typerefs/pkgrefs_test.go @@ -223,12 +223,12 @@ func importFromExportData(pkgPath, exportFile string) (*types.Package, error) { } r, err := gcexportdata.NewReader(file) if err != nil { - file.Close() + file.Close() // ignore error return nil, err } fset := token.NewFileSet() tpkg, err := 
gcexportdata.Read(r, fset, make(map[string]*types.Package), pkgPath) - file.Close() + file.Close() // ignore error if err != nil { return nil, err } diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index 02c5103de37..ed05235a5dc 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -822,7 +822,7 @@ func (c *connection) diagnoseFiles(ctx context.Context, files []protocol.Documen func (c *connection) terminate(ctx context.Context) { // TODO: do we need to handle errors on these calls? - c.Shutdown(ctx) + c.Shutdown(ctx) // ignore error // TODO: right now calling exit terminates the process, we should rethink that // server.Exit(ctx) } diff --git a/gopls/internal/cmd/help_test.go b/gopls/internal/cmd/help_test.go index 74fb07fbe75..7b90b3e8133 100644 --- a/gopls/internal/cmd/help_test.go +++ b/gopls/internal/cmd/help_test.go @@ -39,7 +39,7 @@ func TestHelpFiles(t *testing.T) { var buf bytes.Buffer s := flag.NewFlagSet(page.Name(), flag.ContinueOnError) s.SetOutput(&buf) - tool.Run(ctx, s, page, []string{"-h"}) + tool.Run(ctx, s, page, []string{"-h"}) // ignore error name := page.Name() if name == appName { name = "usage" @@ -70,7 +70,7 @@ func TestVerboseHelp(t *testing.T) { var buf bytes.Buffer s := flag.NewFlagSet(appName, flag.ContinueOnError) s.SetOutput(&buf) - tool.Run(ctx, s, app, []string{"-v", "-h"}) + tool.Run(ctx, s, app, []string{"-v", "-h"}) // ignore error got := buf.Bytes() helpFile := filepath.Join("usage", "usage-v.hlp") diff --git a/gopls/internal/cmd/info.go b/gopls/internal/cmd/info.go index 93a66880234..90baf11004f 100644 --- a/gopls/internal/cmd/info.go +++ b/gopls/internal/cmd/info.go @@ -11,6 +11,7 @@ import ( "context" "flag" "fmt" + "io" "net/url" "os" "sort" @@ -95,8 +96,10 @@ func (v *version) Run(ctx context.Context, args ...string) error { if v.JSON { mode = debug.JSON } - - return debug.PrintVersionInfo(ctx, os.Stdout, v.app.verbose(), mode) + var buf bytes.Buffer + debug.WriteVersionInfo(&buf, 
v.app.verbose(), mode) + _, err := io.Copy(os.Stdout, &buf) + return err } // bug implements the bug command. @@ -175,7 +178,7 @@ func (b *bug) Run(ctx context.Context, args ...string) error { } fmt.Fprintf(public, "\nPlease copy the full information printed by `gopls bug` here, if you are comfortable sharing it.\n\n") } - debug.PrintVersionInfo(ctx, public, true, debug.Markdown) + debug.WriteVersionInfo(public, true, debug.Markdown) body := public.String() title := strings.Join(args, " ") if !strings.HasPrefix(title, goplsBugPrefix) { diff --git a/gopls/internal/cmd/stats.go b/gopls/internal/cmd/stats.go index 1ba43ccee83..51658ab0ed2 100644 --- a/gopls/internal/cmd/stats.go +++ b/gopls/internal/cmd/stats.go @@ -224,7 +224,7 @@ type dirStats struct { // subdirectories. func findDirStats() (dirStats, error) { var ds dirStats - filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { + err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -244,5 +244,5 @@ func findDirStats() (dirStats, error) { } return nil }) - return ds, nil + return ds, err } diff --git a/gopls/internal/debug/info.go b/gopls/internal/debug/info.go index b2824d86f38..fbc7d166d35 100644 --- a/gopls/internal/debug/info.go +++ b/gopls/internal/debug/info.go @@ -6,7 +6,7 @@ package debug import ( - "context" + "bytes" "encoding/json" "fmt" "io" @@ -53,49 +53,48 @@ func VersionInfo() *ServerVersion { } } -// PrintServerInfo writes HTML debug info to w for the Instance. -func (i *Instance) PrintServerInfo(ctx context.Context, w io.Writer) { +// writeServerInfo writes HTML debug info to w for the instance. 
+func (i *Instance) writeServerInfo(out *bytes.Buffer) { workDir, _ := os.Getwd() - section(w, HTML, "Server Instance", func() { - fmt.Fprintf(w, "Start time: %v\n", i.StartTime) - fmt.Fprintf(w, "LogFile: %s\n", i.Logfile) - fmt.Fprintf(w, "pid: %d\n", os.Getpid()) - fmt.Fprintf(w, "Working directory: %s\n", workDir) - fmt.Fprintf(w, "Address: %s\n", i.ServerAddress) - fmt.Fprintf(w, "Debug address: %s\n", i.DebugAddress()) + section(out, HTML, "server instance", func() { + fmt.Fprintf(out, "Start time: %v\n", i.StartTime) + fmt.Fprintf(out, "LogFile: %s\n", i.Logfile) + fmt.Fprintf(out, "pid: %d\n", os.Getpid()) + fmt.Fprintf(out, "Working directory: %s\n", workDir) + fmt.Fprintf(out, "Address: %s\n", i.ServerAddress) + fmt.Fprintf(out, "Debug address: %s\n", i.DebugAddress()) }) - PrintVersionInfo(ctx, w, true, HTML) - section(w, HTML, "Command Line", func() { - fmt.Fprintf(w, "cmdline") + WriteVersionInfo(out, true, HTML) + section(out, HTML, "Command Line", func() { + fmt.Fprintf(out, "cmdline") }) } -// PrintVersionInfo writes version information to w, using the output format +// WriteVersionInfo writes version information to w, using the output format // specified by mode. verbose controls whether additional information is // written, including section headers. 
-func PrintVersionInfo(_ context.Context, w io.Writer, verbose bool, mode PrintMode) error { +func WriteVersionInfo(out *bytes.Buffer, verbose bool, mode PrintMode) { info := VersionInfo() if mode == JSON { - return printVersionInfoJSON(w, info) + writeVersionInfoJSON(out, info) + return } if !verbose { - printBuildInfo(w, info, false, mode) - return nil + writeBuildInfo(out, info, false, mode) + return } - section(w, mode, "Build info", func() { - printBuildInfo(w, info, true, mode) + section(out, mode, "Build info", func() { + writeBuildInfo(out, info, true, mode) }) - return nil } -func printVersionInfoJSON(w io.Writer, info *ServerVersion) error { - js, err := json.MarshalIndent(info, "", "\t") +func writeVersionInfoJSON(out *bytes.Buffer, info *ServerVersion) { + data, err := json.MarshalIndent(info, "", "\t") if err != nil { - return err + panic(err) // can't happen } - _, err = fmt.Fprint(w, string(js)) - return err + out.Write(data) } func section(w io.Writer, mode PrintMode, title string, body func()) { @@ -115,7 +114,7 @@ func section(w io.Writer, mode PrintMode, title string, body func()) { } } -func printBuildInfo(w io.Writer, info *ServerVersion, verbose bool, mode PrintMode) { +func writeBuildInfo(w io.Writer, info *ServerVersion, verbose bool, mode PrintMode) { fmt.Fprintf(w, "%v %v\n", info.Path, version.Version()) if !verbose { return diff --git a/gopls/internal/debug/info_test.go b/gopls/internal/debug/info_test.go index 7f24b696682..6028c187543 100644 --- a/gopls/internal/debug/info_test.go +++ b/gopls/internal/debug/info_test.go @@ -7,7 +7,6 @@ package debug import ( "bytes" - "context" "encoding/json" "runtime" "testing" @@ -17,9 +16,7 @@ import ( func TestPrintVersionInfoJSON(t *testing.T) { buf := new(bytes.Buffer) - if err := PrintVersionInfo(context.Background(), buf, true, JSON); err != nil { - t.Fatalf("PrintVersionInfo failed: %v", err) - } + WriteVersionInfo(buf, true, JSON) res := buf.Bytes() var got ServerVersion @@ -37,9 +34,7 @@ 
func TestPrintVersionInfoJSON(t *testing.T) { func TestPrintVersionInfoPlainText(t *testing.T) { buf := new(bytes.Buffer) - if err := PrintVersionInfo(context.Background(), buf, true, PlainText); err != nil { - t.Fatalf("PrintVersionInfo failed: %v", err) - } + WriteVersionInfo(buf, true, PlainText) res := buf.Bytes() // Other fields of BuildInfo may not be available during test. diff --git a/gopls/internal/debug/serve.go b/gopls/internal/debug/serve.go index b8fdfe0791f..77a86d8c8da 100644 --- a/gopls/internal/debug/serve.go +++ b/gopls/internal/debug/serve.go @@ -323,7 +323,7 @@ func (i *Instance) getFile(r *http.Request) any { func (i *Instance) getInfo(r *http.Request) any { buf := &bytes.Buffer{} - i.PrintServerInfo(r.Context(), buf) + i.writeServerInfo(buf) return template.HTML(buf.String()) } diff --git a/gopls/internal/doc/generate/generate.go b/gopls/internal/doc/generate/generate.go index d470fb71333..5b1acc9f005 100644 --- a/gopls/internal/doc/generate/generate.go +++ b/gopls/internal/doc/generate/generate.go @@ -17,6 +17,7 @@ package main import ( "bytes" "encoding/json" + "errors" "fmt" "go/ast" "go/token" @@ -500,9 +501,10 @@ func loadLenses(settingsPkg *packages.Package, defaults map[settings.CodeLensSou } return nil } - addAll(golang.CodeLensSources(), "Go") - addAll(mod.CodeLensSources(), "go.mod") - return lenses, nil + err := errors.Join( + addAll(golang.CodeLensSources(), "Go"), + addAll(mod.CodeLensSources(), "go.mod")) + return lenses, err } func loadAnalyzers(analyzers []*settings.Analyzer, defaults *settings.Options) []*doc.Analyzer { diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 13793995561..b48841500bd 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -137,13 +137,14 @@ func (i *CompletionItem) Snippet() string { // addConversion wraps the existing completionItem in a conversion expression. 
// Only affects the receiver's InsertText and snippet fields, not the Label. // An empty conv argument has no effect. -func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) error { +func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) { if conv.prefix != "" { // If we are in a selector, add an edit to place prefix before selector. if sel := enclosingSelector(c.path, c.pos); sel != nil { edits, err := c.editText(sel.Pos(), sel.Pos(), conv.prefix) if err != nil { - return err + // safetoken failed: invalid token.Pos information in AST. + return } i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...) } else { @@ -157,8 +158,6 @@ func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) error i.InsertText += conv.suffix i.snippet.WriteText(conv.suffix) } - - return nil } // Scoring constants are used for weighting the relevance of different candidates. @@ -1457,7 +1456,7 @@ func (c *completer) selector(ctx context.Context, sel *ast.SelectorExpr) error { var buf strings.Builder buf.WriteString(name) buf.WriteByte(' ') - cfg.Fprint(&buf, token.NewFileSet(), typ) + cfg.Fprint(&buf, token.NewFileSet(), typ) // ignore error params = append(params, buf.String()) } diff --git a/gopls/internal/golang/completion/unimported.go b/gopls/internal/golang/completion/unimported.go index 9b0d20b08b0..e562f5cd83c 100644 --- a/gopls/internal/golang/completion/unimported.go +++ b/gopls/internal/golang/completion/unimported.go @@ -41,14 +41,14 @@ import ( "golang.org/x/tools/internal/versions" ) -func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName, prefix string) error { +func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName, prefix string) { wsIDs, ourIDs := c.findPackageIDs(pkgname) stdpkgs := c.stdlibPkgs(pkgname) if len(ourIDs) > 0 { // use the one in the current package, if possible items := c.pkgIDmatches(ctx, ourIDs, pkgname, prefix) if c.scoreList(items) { - 
return nil + return } } // do the stdlib next. @@ -63,19 +63,19 @@ func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName if len(x) > 0 { items := c.pkgIDmatches(ctx, x, pkgname, prefix) if c.scoreList(items) { - return nil + return } } // just use the stdlib items := c.stdlibMatches(stdpkgs, pkgname, prefix) if c.scoreList(items) { - return nil + return } // look in the rest of the workspace items = c.pkgIDmatches(ctx, wsIDs, pkgname, prefix) if c.scoreList(items) { - return nil + return } // look in the module cache, for the last chance @@ -83,7 +83,6 @@ func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName if err == nil { c.scoreList(items) } - return nil } // find all the packageIDs for packages in the workspace that have the desired name @@ -126,14 +125,14 @@ func (c *completer) pkgIDmatches(ctx context.Context, ids []metadata.PackageID, return nil // would if be worth retrying the ids one by one? } if len(allpkgsyms) != len(ids) { - bug.Errorf("Symbols returned %d values for %d pkgIDs", len(allpkgsyms), len(ids)) + bug.Reportf("Symbols returned %d values for %d pkgIDs", len(allpkgsyms), len(ids)) return nil } var got []CompletionItem for i, pkgID := range ids { pkg := c.snapshot.MetadataGraph().Packages[pkgID] if pkg == nil { - bug.Errorf("no metadata for %s", pkgID) + bug.Reportf("no metadata for %s", pkgID) continue // something changed underfoot, otherwise can't happen } pkgsyms := allpkgsyms[i] @@ -345,7 +344,7 @@ func funcParams(f *ast.File, fname string) []string { var buf strings.Builder buf.WriteString(name) buf.WriteByte(' ') - cfg.Fprint(&buf, token.NewFileSet(), typ) + cfg.Fprint(&buf, token.NewFileSet(), typ) // ignore error params = append(params, buf.String()) } diff --git a/gopls/internal/golang/hover.go b/gopls/internal/golang/hover.go index dd04f8908c7..369003822d7 100644 --- a/gopls/internal/golang/hover.go +++ b/gopls/internal/golang/hover.go @@ -501,7 +501,7 @@ func hover(ctx 
context.Context, snapshot *cache.Snapshot, fh file.Handle, pp pro types.TypeString(f.field.Type(), qual), f.path) } - w.Flush() + w.Flush() // ignore error b.WriteByte('\n') fields = b.String() } @@ -1046,7 +1046,7 @@ func hoverReturnStatement(pgf *parsego.File, path []ast.Node, ret *ast.ReturnStm if i > 0 { buf.WriteString(", ") } - cfg.Fprint(&buf, fset, field.Type) + cfg.Fprint(&buf, fset, field.Type) // ignore error } buf.WriteByte(')') return rng, &hoverResult{ diff --git a/gopls/internal/licenses/licenses_test.go b/gopls/internal/licenses/licenses_test.go index c31b4e9e659..9892664626c 100644 --- a/gopls/internal/licenses/licenses_test.go +++ b/gopls/internal/licenses/licenses_test.go @@ -20,7 +20,7 @@ func TestLicenses(t *testing.T) { if err != nil { t.Fatal(err) } - tmp.Close() + tmp.Close() // ignore error if out, err := exec.Command("./gen-licenses.sh", tmp.Name()).CombinedOutput(); err != nil { t.Fatalf("generating licenses failed: %q, %v", out, err) diff --git a/gopls/internal/lsprpc/export_test.go b/gopls/internal/lsprpc/export_test.go index 1caf22415cb..999718705e6 100644 --- a/gopls/internal/lsprpc/export_test.go +++ b/gopls/internal/lsprpc/export_test.go @@ -75,7 +75,7 @@ func (b *ForwardBinder) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection) } detached := xcontext.Detach(ctx) go func() { - conn.Wait() + conn.Wait() // ignore error if err := serverConn.Close(); err != nil { event.Log(detached, fmt.Sprintf("closing remote connection: %v", err)) } diff --git a/gopls/internal/lsprpc/lsprpc.go b/gopls/internal/lsprpc/lsprpc.go index f432d64aa76..39ec9bb0ac8 100644 --- a/gopls/internal/lsprpc/lsprpc.go +++ b/gopls/internal/lsprpc/lsprpc.go @@ -248,9 +248,9 @@ func (f *forwarder) ServeStream(ctx context.Context, clientConn jsonrpc2.Conn) e select { case <-serverConn.Done(): - clientConn.Close() + clientConn.Close() // ignore error case <-clientConn.Done(): - serverConn.Close() + serverConn.Close() // ignore error } err = nil diff --git 
a/gopls/internal/lsprpc/middleware_test.go b/gopls/internal/lsprpc/middleware_test.go index afa6ae78d2f..41f70fe6dec 100644 --- a/gopls/internal/lsprpc/middleware_test.go +++ b/gopls/internal/lsprpc/middleware_test.go @@ -79,7 +79,7 @@ func TestHandshakeMiddleware(t *testing.T) { if err := check(true); err != nil { t.Fatalf("after handshake: %v", err) } - conn.Close() + conn.Close() // ignore error // Wait for up to ~2s for connections to get cleaned up. delay := 25 * time.Millisecond for retries := 3; retries >= 0; retries-- { @@ -206,7 +206,7 @@ func (h *Handshaker) nextID() int64 { } func (h *Handshaker) cleanupAtDisconnect(conn *jsonrpc2_v2.Connection, peerID int64) { - conn.Wait() + conn.Wait() // ignore error h.mu.Lock() defer h.mu.Unlock() diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index e8744f76a56..b12e463ec37 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -46,7 +46,7 @@ func Serve(ctx context.Context, address string, eventChan <-chan lsprpc.SessionE // Run the server until cancellation. 
go func() { <-ctx.Done() - svr.Close() + svr.Close() // ignore error }() return svr.Serve(listener) } diff --git a/gopls/internal/protocol/generate/main.go b/gopls/internal/protocol/generate/main.go index ef9bf943606..aa1c47516c2 100644 --- a/gopls/internal/protocol/generate/main.go +++ b/gopls/internal/protocol/generate/main.go @@ -213,7 +213,9 @@ func formatTo(basename string, src []byte) { formatted, err := format.Source(src) if err != nil { failed := filepath.Join("/tmp", basename+".fail") - os.WriteFile(failed, src, 0644) + if err := os.WriteFile(failed, src, 0644); err != nil { + log.Fatal(err) + } log.Fatalf("formatting %s: %v (see %s)", basename, err, failed) } if err := os.WriteFile(filepath.Join(*outputdir, basename), formatted, 0644); err != nil { diff --git a/gopls/internal/server/general.go b/gopls/internal/server/general.go index 6ce1f788dba..0eea8f641f7 100644 --- a/gopls/internal/server/general.go +++ b/gopls/internal/server/general.go @@ -221,7 +221,7 @@ func (s *server) Initialized(ctx context.Context, params *protocol.InitializedPa s.stateMu.Unlock() for _, not := range s.notifications { - s.client.ShowMessage(ctx, not) + s.client.ShowMessage(ctx, not) // ignore error } s.notifications = nil @@ -652,7 +652,7 @@ func (s *server) Shutdown(ctx context.Context) error { if s.state != serverShutDown { // Wait for the webserver (if any) to finish. if s.web != nil { - s.web.server.Shutdown(ctx) + s.web.server.Shutdown(ctx) // ignore error } // drop all the active views @@ -674,7 +674,7 @@ func (s *server) Exit(ctx context.Context) error { s.stateMu.Lock() defer s.stateMu.Unlock() - s.client.Close() + s.client.Close() // ignore error if s.state != serverShutDown { // TODO: We should be able to do better than this. 
diff --git a/gopls/internal/server/workspace.go b/gopls/internal/server/workspace.go index ced5656c6ac..01e2c69d8ee 100644 --- a/gopls/internal/server/workspace.go +++ b/gopls/internal/server/workspace.go @@ -121,7 +121,7 @@ func (s *server) DidChangeConfiguration(ctx context.Context, _ *protocol.DidChan } newFolders = append(newFolders, newFolder) } - s.session.UpdateFolders(ctx, newFolders) + s.session.UpdateFolders(ctx, newFolders) // ignore error // The view set may have been updated above. viewsToDiagnose := make(map[*cache.View][]protocol.DocumentURI) diff --git a/gopls/internal/telemetry/telemetry_test.go b/gopls/internal/telemetry/telemetry_test.go index 1e56012182f..f41769f3ddf 100644 --- a/gopls/internal/telemetry/telemetry_test.go +++ b/gopls/internal/telemetry/telemetry_test.go @@ -32,7 +32,7 @@ func TestMain(m *testing.M) { } countertest.Open(tmp) code := Main(m) - os.RemoveAll(tmp) // golang/go#68243: ignore error; cleanup fails on Windows + os.RemoveAll(tmp) // ignore error (cleanup fails on Windows; golang/go#68243) os.Exit(code) } diff --git a/gopls/internal/test/integration/bench/bench_test.go b/gopls/internal/test/integration/bench/bench_test.go index 9858177b7e0..ef8769d9d5f 100644 --- a/gopls/internal/test/integration/bench/bench_test.go +++ b/gopls/internal/test/integration/bench/bench_test.go @@ -270,8 +270,8 @@ func (s *SidecarServer) Connect(ctx context.Context) jsonrpc2.Conn { go func() { select { case <-ctx.Done(): - clientConn.Close() - clientStream.Close() + clientConn.Close() // ignore error + clientStream.Close() // ignore error case <-clientConn.Done(): } }() diff --git a/gopls/internal/test/integration/diagnostics/diagnostics_test.go b/gopls/internal/test/integration/diagnostics/diagnostics_test.go index 222077d2e55..978d6b22dec 100644 --- a/gopls/internal/test/integration/diagnostics/diagnostics_test.go +++ b/gopls/internal/test/integration/diagnostics/diagnostics_test.go @@ -266,7 +266,7 @@ func TestDeleteTestVariant_DiskOnly(t 
*testing.T) { Run(t, test38878, func(t *testing.T, env *Env) { env.OpenFile("a_test.go") env.AfterChange(Diagnostics(AtPosition("a_test.go", 5, 3))) - env.Sandbox.Workdir.RemoveFile(context.Background(), "a_test.go") + env.Sandbox.Workdir.RemoveFile(context.Background(), "a_test.go") // ignore error env.AfterChange(Diagnostics(AtPosition("a_test.go", 5, 3))) }) } @@ -1136,7 +1136,7 @@ package main func main() {} ` Run(t, basic, func(t *testing.T, env *Env) { - env.Editor.CreateBuffer(env.Ctx, "foo.go", `package main`) + env.CreateBuffer("foo.go", `package main`) env.AfterChange() env.CloseBuffer("foo.go") env.AfterChange(NoLogMatching(protocol.Info, "packages=0")) diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index b5d8e4ccda0..4acfb06393d 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -1579,8 +1579,7 @@ func (e *Editor) applyTextDocumentEdit(ctx context.Context, change protocol.Text // TODO: it's unclear if this is correct. Here we create the buffer (with // version 1), then apply edits. Perhaps we should apply the edits before // sending the didOpen notification. - e.CreateBuffer(ctx, path, "") - err = nil + err = e.CreateBuffer(ctx, path, "") } if err != nil { return err diff --git a/gopls/internal/test/integration/fake/sandbox.go b/gopls/internal/test/integration/fake/sandbox.go index 1d8918babd4..22352dbda9c 100644 --- a/gopls/internal/test/integration/fake/sandbox.go +++ b/gopls/internal/test/integration/fake/sandbox.go @@ -85,7 +85,7 @@ func NewSandbox(config *SandboxConfig) (_ *Sandbox, err error) { defer func() { // Clean up if we fail at any point in this constructor. 
if err != nil { - sb.Close() + sb.Close() // ignore error } }() diff --git a/gopls/internal/test/integration/fake/workdir.go b/gopls/internal/test/integration/fake/workdir.go index 54fabb358c3..b430f31b544 100644 --- a/gopls/internal/test/integration/fake/workdir.go +++ b/gopls/internal/test/integration/fake/workdir.go @@ -295,7 +295,7 @@ func (w *Workdir) RenameFile(ctx context.Context, oldPath, newPath string) error if err := robustio.RemoveAll(oldAbs); err != nil { // If we failed to remove the old file, that may explain the Rename error too. // Make a best effort to back out the write to the new path. - robustio.RemoveAll(newAbs) + robustio.RemoveAll(newAbs) // ignore error return renameErr } } diff --git a/gopls/internal/test/integration/misc/imports_test.go b/gopls/internal/test/integration/misc/imports_test.go index bdb5ea25318..fc946d7e809 100644 --- a/gopls/internal/test/integration/misc/imports_test.go +++ b/gopls/internal/test/integration/misc/imports_test.go @@ -283,7 +283,7 @@ return nil for k, v := range mx { fname := filepath.Join(modcache, k) dir := filepath.Dir(fname) - os.MkdirAll(dir, 0777) + os.MkdirAll(dir, 0777) // ignore error if err := os.WriteFile(fname, v, 0644); err != nil { t.Fatal(err) } @@ -333,7 +333,7 @@ return nil for k, v := range mx { fname := filepath.Join(modcache, k) dir := filepath.Dir(fname) - os.MkdirAll(dir, 0777) + os.MkdirAll(dir, 0777) // ignore error if err := os.WriteFile(fname, v, 0644); err != nil { t.Fatal(err) } @@ -384,7 +384,7 @@ return nil for k, v := range mx { fname := filepath.Join(modcache, k) dir := filepath.Dir(fname) - os.MkdirAll(dir, 0777) + os.MkdirAll(dir, 0777) // ignore error if err := os.WriteFile(fname, v, 0644); err != nil { t.Fatal(err) } @@ -636,7 +636,7 @@ var A int for k, v := range mx { fname := filepath.Join(modcache, k) dir := filepath.Dir(fname) - os.MkdirAll(dir, 0777) + os.MkdirAll(dir, 0777) // ignore error if err := os.WriteFile(fname, v, 0644); err != nil { t.Fatal(err) } diff --git 
a/gopls/internal/test/integration/misc/misc_test.go b/gopls/internal/test/integration/misc/misc_test.go index ca0125894c8..2b9ad7fe6ba 100644 --- a/gopls/internal/test/integration/misc/misc_test.go +++ b/gopls/internal/test/integration/misc/misc_test.go @@ -23,7 +23,7 @@ func TestMain(m *testing.M) { } countertest.Open(tmp) code := Main(m) - os.RemoveAll(tmp) // golang/go#68243: ignore error; cleanup fails on Windows + os.RemoveAll(tmp) // ignore error (cleanup fails on Windows; golang/go#68243) os.Exit(code) } diff --git a/gopls/internal/test/integration/misc/test_test.go b/gopls/internal/test/integration/misc/test_test.go index b282bf57a95..3dfc70d5407 100644 --- a/gopls/internal/test/integration/misc/test_test.go +++ b/gopls/internal/test/integration/misc/test_test.go @@ -19,7 +19,7 @@ import ( func TestRunTestsAndBenchmarks(t *testing.T) { file := filepath.Join(t.TempDir(), "out") - os.Setenv("TESTFILE", file) + os.Setenv("TESTFILE", file) // ignore error const src = ` -- go.mod -- diff --git a/gopls/internal/test/integration/modfile/modfile_test.go b/gopls/internal/test/integration/modfile/modfile_test.go index dfd50c3effb..36ed9cf4138 100644 --- a/gopls/internal/test/integration/modfile/modfile_test.go +++ b/gopls/internal/test/integration/modfile/modfile_test.go @@ -598,12 +598,12 @@ func main() { Diagnostics(env.AtRegexp("a/main.go", "x = ")), ) env.RegexpReplace("a/go.mod", "v1.2.3", "v1.2.2") - env.Editor.SaveBuffer(env.Ctx, "a/go.mod") // go.mod changes must be on disk + env.SaveBuffer("a/go.mod") // go.mod changes must be on disk env.AfterChange( Diagnostics(env.AtRegexp("a/go.mod", "example.com v1.2.2")), ) env.RegexpReplace("a/go.mod", "v1.2.2", "v1.2.3") - env.Editor.SaveBuffer(env.Ctx, "a/go.mod") // go.mod changes must be on disk + env.SaveBuffer("a/go.mod") // go.mod changes must be on disk env.AfterChange( Diagnostics(env.AtRegexp("a/main.go", "x = ")), ) diff --git a/gopls/internal/test/integration/regtest.go 
b/gopls/internal/test/integration/regtest.go index dc9600af7df..1ca077a8f57 100644 --- a/gopls/internal/test/integration/regtest.go +++ b/gopls/internal/test/integration/regtest.go @@ -157,7 +157,7 @@ func Main(m *testing.M) (code int) { flag.Parse() // Disable GOPACKAGESDRIVER, as it can cause spurious test failures. - os.Setenv("GOPACKAGESDRIVER", "off") + os.Setenv("GOPACKAGESDRIVER", "off") // ignore error if skipReason := checkBuilder(); skipReason != "" { fmt.Printf("Skipping all tests: %s\n", skipReason) @@ -213,8 +213,8 @@ func FilterToolchainPathAndGOROOT() { if localGo, first := findLocalGo(); localGo != "" && !first { dir := filepath.Dir(localGo) path := os.Getenv("PATH") - os.Setenv("PATH", dir+string(os.PathListSeparator)+path) - os.Unsetenv("GOROOT") // Remove the GOROOT value that was added by toolchain switch. + os.Setenv("PATH", dir+string(os.PathListSeparator)+path) // ignore error + os.Unsetenv("GOROOT") // Remove the GOROOT value that was added by toolchain switch. } } diff --git a/gopls/internal/test/integration/runner.go b/gopls/internal/test/integration/runner.go index 96427461580..9b0a7c27024 100644 --- a/gopls/internal/test/integration/runner.go +++ b/gopls/internal/test/integration/runner.go @@ -194,7 +194,7 @@ func (r *Runner) Run(t *testing.T, files string, test TestFunc, opts ...RunOptio defer func() { if !r.SkipCleanup { if err := sandbox.Close(); err != nil { - pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) + pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) // ignore error t.Errorf("closing the sandbox: %v", err) } } @@ -217,7 +217,7 @@ func (r *Runner) Run(t *testing.T, files string, test TestFunc, opts ...RunOptio env := ConnectGoplsEnv(t, ctx, sandbox, config.editor, ts) defer func() { if t.Failed() && r.PrintGoroutinesOnFailure { - pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) + pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) // ignore error } if (t.Failed() && !config.noLogsOnError) || *printLogs { ls.printBuffers(t.Name(), 
os.Stderr) diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index a41c2f670cf..83df6781662 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -58,7 +58,7 @@ func TestMain(m *testing.M) { bug.PanicOnBugs = true testenv.ExitIfSmallMachine() // Disable GOPACKAGESDRIVER, as it can cause spurious test failures. - os.Setenv("GOPACKAGESDRIVER", "off") + os.Setenv("GOPACKAGESDRIVER", "off") // ignore error integration.FilterToolchainPathAndGOROOT() os.Exit(m.Run()) } diff --git a/gopls/internal/vulncheck/vulntest/db.go b/gopls/internal/vulncheck/vulntest/db.go index 9a5c054520d..5f03927d14f 100644 --- a/gopls/internal/vulncheck/vulntest/db.go +++ b/gopls/internal/vulncheck/vulntest/db.go @@ -43,7 +43,7 @@ func NewDatabase(ctx context.Context, txtarReports []byte) (*DB, error) { return nil, err } if err := generateDB(ctx, txtarReports, disk, false); err != nil { - os.RemoveAll(disk) + os.RemoveAll(disk) // ignore error return nil, err } From dd6ec0425d023308b501cea8b8c4c7b5f4a11cd2 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 14:36:52 -0400 Subject: [PATCH 158/196] gopls/internal/settings: add maprange analyzer + relnote Updates golang/go#72908 Change-Id: Ib3f0210249996ea48bedb38095064d1a8934a406 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677536 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- gopls/doc/analyzers.md | 33 +++++++++++++++++++ gopls/doc/release/v0.19.0.md | 6 ++++ gopls/internal/analysis/maprange/main.go | 15 +++++++++ gopls/internal/doc/api.json | 12 +++++++ gopls/internal/settings/analysis.go | 2 ++ .../marker/testdata/diagnostics/analyzers.txt | 9 +++++ 6 files changed, 77 insertions(+) create mode 100644 gopls/internal/analysis/maprange/main.go diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index 06892852319..c2bb5a6ad4f 100644 --- a/gopls/doc/analyzers.md +++ 
b/gopls/doc/analyzers.md @@ -3753,6 +3753,39 @@ Default: on. Package documentation: [lostcancel](https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel) + +## `maprange`: checks for unnecessary calls to maps.Keys and maps.Values in range statements + + +Consider a loop written like this: + + for val := range maps.Values(m) { + fmt.Println(val) + } + +This should instead be written without the call to maps.Values: + + for _, val := range m { + fmt.Println(val) + } + +golang.org/x/exp/maps returns slices for Keys/Values instead of iterators, +but unnecessary calls should similarly be removed: + + for _, key := range maps.Keys(m) { + fmt.Println(key) + } + +should be rewritten as: + + for key := range m { + fmt.Println(key) + } + +Default: on. + +Package documentation: [maprange](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/maprange) + ## `modernize`: simplify code by using modern constructs diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index 94f225a800f..05aeb2ec738 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -59,6 +59,12 @@ The new `recursiveiter` analyzer detects such mistakes; see documentation) for details, including tips on how to define simple and efficient recursive iterators. +## "Inefficient range over maps.Keys/Values" analyzer + +This analyzer detects redundant calls to `maps.Keys` or `maps.Values` +as the operand of a range loop; maps can of course be ranged over +directly. + ## "Implementations" supports signature types The Implementations query reports the correspondence between abstract diff --git a/gopls/internal/analysis/maprange/main.go b/gopls/internal/analysis/maprange/main.go new file mode 100644 index 00000000000..2ed5b36df08 --- /dev/null +++ b/gopls/internal/analysis/maprange/main.go @@ -0,0 +1,15 @@ +// Copyright 2024 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +// The maprange command runs the maprange analyzer. +package main + +import ( + "golang.org/x/tools/go/analysis/singlechecker" + "golang.org/x/tools/gopls/internal/analysis/maprange" +) + +func main() { singlechecker.Main(maprange.Analyzer) } diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index cd325f364a3..a77b64c473a 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1490,6 +1490,12 @@ "Default": "true", "Status": "" }, + { + "Name": "\"maprange\"", + "Doc": "checks for unnecessary calls to maps.Keys and maps.Values in range statements\n\nConsider a loop written like this:\n\n\tfor val := range maps.Values(m) {\n\t\tfmt.Println(val)\n\t}\n\nThis should instead be written without the call to maps.Values:\n\n\tfor _, val := range m {\n\t\tfmt.Println(val)\n\t}\n\ngolang.org/x/exp/maps returns slices for Keys/Values instead of iterators,\nbut unnecessary calls should similarly be removed:\n\n\tfor _, key := range maps.Keys(m) {\n\t\tfmt.Println(key)\n\t}\n\nshould be rewritten as:\n\n\tfor key := range m {\n\t\tfmt.Println(key)\n\t}", + "Default": "true", + "Status": "" + }, { "Name": "\"modernize\"", "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix.
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", @@ -3216,6 +3222,12 @@ "URL": "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel", "Default": true }, + { + "Name": "maprange", + "Doc": "checks for unnecessary calls to maps.Keys and maps.Values in range statements\n\nConsider a loop written like this:\n\n\tfor val := range maps.Values(m) {\n\t\tfmt.Println(val)\n\t}\n\nThis should instead be written without the call to maps.Values:\n\n\tfor _, val := range m {\n\t\tfmt.Println(val)\n\t}\n\ngolang.org/x/exp/maps returns slices for Keys/Values instead of iterators,\nbut unnecessary calls should similarly be removed:\n\n\tfor _, key := range maps.Keys(m) {\n\t\tfmt.Println(key)\n\t}\n\nshould be rewritten as:\n\n\tfor key := range m {\n\t\tfmt.Println(key)\n\t}", + "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/maprange", + 
"Default": true + }, { "Name": "modernize", "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", diff --git a/gopls/internal/settings/analysis.go b/gopls/internal/settings/analysis.go index 59b88ba840f..ccc06b9ffea 100644 --- a/gopls/internal/settings/analysis.go +++ b/gopls/internal/settings/analysis.go @@ -53,6 +53,7 @@ import ( "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillreturns" "golang.org/x/tools/gopls/internal/analysis/infertypeargs" + "golang.org/x/tools/gopls/internal/analysis/maprange" "golang.org/x/tools/gopls/internal/analysis/modernize" "golang.org/x/tools/gopls/internal/analysis/nonewvars" "golang.org/x/tools/gopls/internal/analysis/noresultvalues" @@ -245,6 +246,7 @@ var DefaultAnalyzers = []*Analyzer{ {analyzer: unusedfunc.Analyzer, severity: protocol.SeverityInformation}, {analyzer: unusedwrite.Analyzer, severity: 
protocol.SeverityInformation}, // uses go/ssa {analyzer: modernize.Analyzer, severity: protocol.SeverityHint}, + {analyzer: maprange.Analyzer, severity: protocol.SeverityHint}, // type-error analyzers // These analyzers enrich go/types errors with suggested fixes. diff --git a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt index 252b4b4180a..c129a3d3b81 100644 --- a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt +++ b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt @@ -106,6 +106,15 @@ func _(c chan bool) { C.f(unsafe.Pointer(&c)) //@ diag("unsafe", re"passing Go type with embedded pointer to C") } +-- maprange/maprange.go -- +package maprange + +import "maps" + +func _(m map[int]int) { + for range maps.Keys(m) {} //@ diag("maps.Keys", re"unnecessary and inefficient call of maps.Keys") +} + -- staticcheck/staticcheck.go -- package staticcheck From 661b815f07f317c37d51d33298e8baafb480e6c0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 14:45:58 -0400 Subject: [PATCH 159/196] go/analysis/passes/unusedresult: add slices, maps functions Also, report the diagnostic about the entire f part of f(x). 
Change-Id: Iab84accf798f60a282c34278b23cf174b60840fd Reviewed-on: https://go-review.googlesource.com/c/tools/+/677538 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley Commit-Queue: Alan Donovan --- .../passes/unusedresult/unusedresult.go | 81 ++++++++++++++----- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/go/analysis/passes/unusedresult/unusedresult.go b/go/analysis/passes/unusedresult/unusedresult.go index 932f1347e56..193cabc2022 100644 --- a/go/analysis/passes/unusedresult/unusedresult.go +++ b/go/analysis/passes/unusedresult/unusedresult.go @@ -26,6 +26,7 @@ import ( "golang.org/x/tools/go/analysis/passes/internal/analysisutil" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/types/typeutil" + "golang.org/x/tools/internal/analysisinternal" ) //go:embed doc.go @@ -59,23 +60,63 @@ func init() { // The context.With{Cancel,Deadline,Timeout} entries are // effectively redundant wrt the lostcancel analyzer. funcs = stringSetFlag{ - "context.WithCancel": true, - "context.WithDeadline": true, - "context.WithTimeout": true, - "context.WithValue": true, - "errors.New": true, - "fmt.Errorf": true, - "fmt.Sprint": true, - "fmt.Sprintf": true, - "slices.Clip": true, - "slices.Compact": true, - "slices.CompactFunc": true, - "slices.Delete": true, - "slices.DeleteFunc": true, - "slices.Grow": true, - "slices.Insert": true, - "slices.Replace": true, - "sort.Reverse": true, + "context.WithCancel": true, + "context.WithDeadline": true, + "context.WithTimeout": true, + "context.WithValue": true, + "errors.New": true, + "fmt.Append": true, + "fmt.Appendf": true, + "fmt.Appendln": true, + "fmt.Errorf": true, + "fmt.Sprint": true, + "fmt.Sprintf": true, + "fmt.Sprintln": true, + "maps.All": true, + "maps.Clone": true, + "maps.Collect": true, + "maps.Copy": true, + "maps.Equal": true, + "maps.EqualFunc": true, + "maps.Keys": true, + "maps.Values": true, + "slices.All": true, + "slices.AppendSeq": true, + 
"slices.Backward": true, + "slices.BinarySearch": true, + "slices.BinarySearchFunc": true, + "slices.Chunk": true, + "slices.Clip": true, + "slices.Clone": true, + "slices.Collect": true, + "slices.Compact": true, + "slices.CompactFunc": true, + "slices.Compare": true, + "slices.CompareFunc": true, + "slices.Concat": true, + "slices.Contains": true, + "slices.ContainsFunc": true, + "slices.Delete": true, + "slices.DeleteFunc": true, + "slices.Equal": true, + "slices.EqualFunc": true, + "slices.Grow": true, + "slices.Index": true, + "slices.IndexFunc": true, + "slices.Insert": true, + "slices.IsSorted": true, + "slices.IsSortedFunc": true, + "slices.Max": true, + "slices.MaxFunc": true, + "slices.Min": true, + "slices.MinFunc": true, + "slices.Repeat": true, + "slices.Replace": true, + "slices.Sorted": true, + "slices.SortedFunc": true, + "slices.SortedStableFunc": true, + "slices.Values": true, + "sort.Reverse": true, } Analyzer.Flags.Var(&funcs, "funcs", "comma-separated list of functions whose results must be used") @@ -114,14 +155,16 @@ func run(pass *analysis.Pass) (any, error) { // method (e.g. foo.String()) if types.Identical(sig, sigNoArgsStringResult) { if stringMethods[fn.Name()] { - pass.Reportf(call.Lparen, "result of (%s).%s call not used", + pass.ReportRangef(analysisinternal.Range(call.Pos(), call.Lparen), + "result of (%s).%s call not used", sig.Recv().Type(), fn.Name()) } } } else { // package-level function (e.g. 
fmt.Errorf) if pkgFuncs[[2]string{fn.Pkg().Path(), fn.Name()}] { - pass.Reportf(call.Lparen, "result of %s.%s call not used", + pass.ReportRangef(analysisinternal.Range(call.Pos(), call.Lparen), + "result of %s.%s call not used", fn.Pkg().Path(), fn.Name()) } } From ef3a8dc581e411d79ecfa0c79597387fe930feaf Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 14:45:58 -0400 Subject: [PATCH 160/196] go/analysis/passes/unusedresult: add test The previous CL 677538 cannot be tested by the analysistest framework, so I added coverage in the form of a gopls marker test. Change-Id: Ie779ab458b64583949045bbdd21b9c469588c41b Reviewed-on: https://go-review.googlesource.com/c/tools/+/677539 Reviewed-by: Robert Findley Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- .../test/marker/testdata/diagnostics/analyzers.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt index c129a3d3b81..1535e229f65 100644 --- a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt +++ b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt @@ -115,6 +115,15 @@ func _(m map[int]int) { for range maps.Keys(m) {} //@ diag("maps.Keys", re"unnecessary and inefficient call of maps.Keys") } +-- unusedresult/unusedresult.go -- +package unusedresult + +import "fmt" + +func _() { + fmt.Appendf(nil, "%d", 1) //@ diag("fmt.Appendf", re"result.*not used") +} + -- staticcheck/staticcheck.go -- package staticcheck From f80f3ff1d33581961cfa82f97175a12aa85b196a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 30 May 2025 17:49:33 -0400 Subject: [PATCH 161/196] gopls/internal/protocol: add Mapper.{Pos,Node}Text helpers These methods return the text of a token.Pos range or Node. Also, move Mapper.RangeLocation to DocumentURI.Location, and use it in more places.
Change-Id: I62f21631feb92d2ea330ff04ddde77e80170d18b Reviewed-on: https://go-review.googlesource.com/c/tools/+/677519 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Alan Donovan --- gopls/internal/cache/check.go | 4 ++-- gopls/internal/cache/errors.go | 12 +++------- gopls/internal/cache/parsego/file.go | 10 +++++++++ gopls/internal/cmd/cmd.go | 2 +- gopls/internal/golang/addtest.go | 8 +++---- gopls/internal/golang/call_hierarchy.go | 5 +---- gopls/internal/golang/codeaction.go | 7 +++--- gopls/internal/golang/comment.go | 4 ++-- gopls/internal/golang/compileropt.go | 7 ++---- gopls/internal/golang/completion/newfile.go | 4 ++-- gopls/internal/golang/extract.go | 4 ++-- gopls/internal/golang/extracttofile.go | 8 +++---- gopls/internal/golang/folding_range.go | 4 ++-- gopls/internal/golang/pkgdoc.go | 12 ++++------ gopls/internal/golang/workspace_symbol.go | 9 +++----- gopls/internal/protocol/mapper.go | 19 +++++++++++----- gopls/internal/protocol/uri.go | 5 +++++ gopls/internal/server/symbols.go | 5 +---- gopls/internal/template/implementations.go | 4 ++-- .../internal/test/integration/fake/editor.go | 2 +- gopls/internal/test/marker/marker_test.go | 22 +++++++------------ 21 files changed, 76 insertions(+), 81 deletions(-) diff --git a/gopls/internal/cache/check.go b/gopls/internal/cache/check.go index a23ec4c2937..ecb4da2a48b 100644 --- a/gopls/internal/cache/check.go +++ b/gopls/internal/cache/check.go @@ -2129,11 +2129,11 @@ func typeErrorsToDiagnostics(pkg *syntaxPackage, inputs *typeCheckInputs, errs [ if i > 0 && len(diags) > 0 { primary := diags[0] primary.Related = append(primary.Related, protocol.DiagnosticRelatedInformation{ - Location: protocol.Location{URI: diag.URI, Range: diag.Range}, + Location: diag.URI.Location(diag.Range), Message: related[i].Msg, // use the unmodified secondary error for related errors. 
}) diag.Related = []protocol.DiagnosticRelatedInformation{{ - Location: protocol.Location{URI: primary.URI, Range: primary.Range}, + Location: primary.URI.Location(primary.Range), }} } diags = append(diags, diag) diff --git a/gopls/internal/cache/errors.go b/gopls/internal/cache/errors.go index 39eb8387702..26f30c8c4dc 100644 --- a/gopls/internal/cache/errors.go +++ b/gopls/internal/cache/errors.go @@ -168,11 +168,8 @@ func encodeDiagnostics(srcDiags []*Diagnostic) []byte { for uri, srcEdits := range srcFix.Edits { for _, srcEdit := range srcEdits { gobFix.TextEdits = append(gobFix.TextEdits, gobTextEdit{ - Location: protocol.Location{ - URI: uri, - Range: srcEdit.Range, - }, - NewText: []byte(srcEdit.NewText), + Location: uri.Location(srcEdit.Range), + NewText: []byte(srcEdit.NewText), }) } } @@ -191,10 +188,7 @@ func encodeDiagnostics(srcDiags []*Diagnostic) []byte { gobRelated = append(gobRelated, gobRel) } gobDiag := gobDiagnostic{ - Location: protocol.Location{ - URI: srcDiag.URI, - Range: srcDiag.Range, - }, + Location: srcDiag.URI.Location(srcDiag.Range), Severity: srcDiag.Severity, Code: srcDiag.Code, CodeHref: srcDiag.CodeHref, diff --git a/gopls/internal/cache/parsego/file.go b/gopls/internal/cache/parsego/file.go index ef8a3379b03..7254e1f4621 100644 --- a/gopls/internal/cache/parsego/file.go +++ b/gopls/internal/cache/parsego/file.go @@ -89,6 +89,11 @@ func (pgf *File) PosLocation(start, end token.Pos) (protocol.Location, error) { return pgf.Mapper.PosLocation(pgf.Tok, start, end) } +// PosText returns the source text for the token.Pos interval in this file. +func (pgf *File) PosText(start, end token.Pos) ([]byte, error) { + return pgf.Mapper.PosText(pgf.Tok, start, end) +} + // NodeRange returns a protocol Range for the ast.Node interval in this file. 
func (pgf *File) NodeRange(node ast.Node) (protocol.Range, error) { return pgf.Mapper.NodeRange(pgf.Tok, node) @@ -104,6 +109,11 @@ func (pgf *File) NodeLocation(node ast.Node) (protocol.Location, error) { return pgf.Mapper.PosLocation(pgf.Tok, node.Pos(), node.End()) } +// NodeText returns the source text for the ast.Node interval in this file. +func (pgf *File) NodeText(node ast.Node) ([]byte, error) { + return pgf.Mapper.NodeText(pgf.Tok, node) +} + // RangePos parses a protocol Range back into the go/token domain. func (pgf *File) RangePos(r protocol.Range) (token.Pos, token.Pos, error) { start, end, err := pgf.Mapper.RangeOffsets(r) diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index ed05235a5dc..a572622e682 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -887,7 +887,7 @@ func (f *cmdFile) spanLocation(s span) (protocol.Location, error) { if err != nil { return protocol.Location{}, err } - return f.mapper.RangeLocation(rng), nil + return f.mapper.URI.Location(rng), nil } // spanRange converts a (UTF-8) span to a protocol (UTF-16) range. diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index da9a8ecc88c..73665ce9755 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -320,11 +320,11 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. // Search for something that looks like a copyright header, to replicate // in the new file. if c := CopyrightComment(pgf.File); c != nil { - start, end, err := pgf.NodeOffsets(c) + text, err := pgf.NodeText(c) if err != nil { return nil, err } - header.Write(pgf.Src[start:end]) + header.Write(text) // One empty line between copyright header and following. header.WriteString("\n\n") } @@ -332,11 +332,11 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. // If this test file was created by gopls, add build constraints // matching the non-test file. 
if c := buildConstraintComment(pgf.File); c != nil { - start, end, err := pgf.NodeOffsets(c) + text, err := pgf.NodeText(c) if err != nil { return nil, err } - header.Write(pgf.Src[start:end]) + header.Write(text) // One empty line between build constraint and following. header.WriteString("\n\n") } diff --git a/gopls/internal/golang/call_hierarchy.go b/gopls/internal/golang/call_hierarchy.go index b9f21cd18d7..1193d7e8de8 100644 --- a/gopls/internal/golang/call_hierarchy.go +++ b/gopls/internal/golang/call_hierarchy.go @@ -88,10 +88,7 @@ func IncomingCalls(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle event.Error(ctx, fmt.Sprintf("error getting enclosing node for %q", ref.pkgPath), err) continue } - loc := protocol.Location{ - URI: callItem.URI, - Range: callItem.Range, - } + loc := callItem.URI.Location(callItem.Range) call, ok := incomingCalls[loc] if !ok { call = &protocol.CallHierarchyIncomingCall{From: callItem} diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 3d43d5694dc..703b06bc6a2 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -41,8 +41,7 @@ import ( // // See ../protocol/codeactionkind.go for some code action theory. 
func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, rng protocol.Range, diagnostics []protocol.Diagnostic, enabled func(protocol.CodeActionKind) bool, trigger protocol.CodeActionTriggerKind) (actions []protocol.CodeAction, _ error) { - - loc := protocol.Location{URI: fh.URI(), Range: rng} + loc := fh.URI().Location(rng) pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) if err != nil { @@ -513,11 +512,11 @@ func refactorExtractVariableAll(ctx context.Context, req *codeActionsRequest) er // Don't suggest if only one expr is found, // otherwise it will duplicate with [refactorExtractVariable] if exprs, err := canExtractVariable(info, req.pgf.Cursor, req.start, req.end, true); err == nil && len(exprs) > 1 { - start, end, err := req.pgf.NodeOffsets(exprs[0]) + text, err := req.pgf.NodeText(exprs[0]) if err != nil { return err } - desc := string(req.pgf.Src[start:end]) + desc := string(text) if len(desc) >= 40 || strings.Contains(desc, "\n") { desc = astutil.NodeDescription(exprs[0]) } diff --git a/gopls/internal/golang/comment.go b/gopls/internal/golang/comment.go index a58045b1819..64636573b8b 100644 --- a/gopls/internal/golang/comment.go +++ b/gopls/internal/golang/comment.go @@ -103,12 +103,12 @@ func parseDocLink(pkg *cache.Package, pgf *parsego.File, pos token.Pos) (types.O end = comment.End() } - offsetStart, offsetEnd, err := safetoken.Offsets(pgf.Tok, start, end) + textBytes, err := pgf.PosText(start, end) if err != nil { return nil, protocol.Range{}, err } - text := string(pgf.Src[offsetStart:offsetEnd]) + text := string(textBytes) lineOffset := int(pos - start) for _, idx := range docLinkRegex.FindAllStringSubmatchIndex(text, -1) { diff --git a/gopls/internal/golang/compileropt.go b/gopls/internal/golang/compileropt.go index df6c58145bf..ab219167ebf 100644 --- a/gopls/internal/golang/compileropt.go +++ b/gopls/internal/golang/compileropt.go @@ -149,11 +149,8 @@ func parseDetailsFile(filename string, options *settings.Options) 
(protocol.Docu var related []protocol.DiagnosticRelatedInformation for _, ri := range d.RelatedInformation { related = append(related, protocol.DiagnosticRelatedInformation{ - Location: protocol.Location{ - URI: ri.Location.URI, - Range: zeroIndexedRange(ri.Location.Range), - }, - Message: ri.Message, + Location: ri.Location.URI.Location(zeroIndexedRange(ri.Location.Range)), + Message: ri.Message, }) } diagnostic := &cache.Diagnostic{ diff --git a/gopls/internal/golang/completion/newfile.go b/gopls/internal/golang/completion/newfile.go index 38dcadc238f..3208317145e 100644 --- a/gopls/internal/golang/completion/newfile.go +++ b/gopls/internal/golang/completion/newfile.go @@ -40,11 +40,11 @@ func NewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) (*pr continue } if group := golang.CopyrightComment(pgf.File); group != nil { - start, end, err := pgf.NodeOffsets(group) + text, err := pgf.NodeText(group) if err != nil { continue } - buf.Write(pgf.Src[start:end]) + buf.Write(text) buf.WriteString("\n\n") break } diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index cc0ee536b1c..37c3352a68f 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -1223,11 +1223,11 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to for i, br := range freeBranches { // Preserve spacing at the beginning of the line containing the branch statement. 
startPos := tok.LineStart(safetoken.Line(tok, br.Pos())) - start, end, err := safetoken.Offsets(tok, startPos, br.End()) + text, err := pgf.PosText(startPos, br.End()) if err != nil { return nil, nil, err } - fmt.Fprintf(&fullReplacement, "case %d:\n%s%s", i+1, pgf.Src[start:end], newLineIndent) + fmt.Fprintf(&fullReplacement, "case %d:\n%s%s", i+1, text, newLineIndent) } fullReplacement.WriteString("}") } diff --git a/gopls/internal/golang/extracttofile.go b/gopls/internal/golang/extracttofile.go index cc833f12c42..6aef6636e48 100644 --- a/gopls/internal/golang/extracttofile.go +++ b/gopls/internal/golang/extracttofile.go @@ -153,21 +153,21 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han var buf bytes.Buffer if c := CopyrightComment(pgf.File); c != nil { - start, end, err := pgf.NodeOffsets(c) + text, err := pgf.NodeText(c) if err != nil { return nil, err } - buf.Write(pgf.Src[start:end]) + buf.Write(text) // One empty line between copyright header and following. buf.WriteString("\n\n") } if c := buildConstraintComment(pgf.File); c != nil { - start, end, err := pgf.NodeOffsets(c) + text, err := pgf.NodeText(c) if err != nil { return nil, err } - buf.Write(pgf.Src[start:end]) + buf.Write(text) // One empty line between build constraint and following. buf.WriteString("\n\n") } diff --git a/gopls/internal/golang/folding_range.go b/gopls/internal/golang/folding_range.go index 2cf9f9a6b94..d88def1b739 100644 --- a/gopls/internal/golang/folding_range.go +++ b/gopls/internal/golang/folding_range.go @@ -174,12 +174,12 @@ func getLineFoldingRange(pgf *parsego.File, open, close token.Pos, lineFoldingOn // isOnlySpaceBetween returns true if there are only space characters between "from" and "to". 
isOnlySpaceBetween := func(from token.Pos, to token.Pos) bool { - start, end, err := safetoken.Offsets(pgf.Tok, from, to) + text, err := pgf.PosText(from, to) if err != nil { bug.Reportf("failed to get offsets: %s", err) // can't happen return false } - return len(bytes.TrimSpace(pgf.Src[start:end])) == 0 + return len(bytes.TrimSpace(text)) == 0 } nextLine := safetoken.Line(pgf.Tok, open) + 1 diff --git a/gopls/internal/golang/pkgdoc.go b/gopls/internal/golang/pkgdoc.go index 9f2b2bf51a4..bd4c2cecfef 100644 --- a/gopls/internal/golang/pkgdoc.go +++ b/gopls/internal/golang/pkgdoc.go @@ -49,7 +49,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" - "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/stdlib" "golang.org/x/tools/internal/typesinternal" ) @@ -577,15 +576,12 @@ window.addEventListener('load', function() { if !to.IsValid() { bug.Reportf("invalid Pos") } - start, err := safetoken.Offset(file.Tok, pos) + text, err := file.PosText(pos, to) if err != nil { - bug.Reportf("invalid start Pos: %v", err) + bug.Reportf("invalid pos range: %v", err) + return } - end, err := safetoken.Offset(file.Tok, to) - if err != nil { - bug.Reportf("invalid end Pos: %v", err) - } - buf.WriteString(escape(string(file.Src[start:end]))) + buf.WriteString(escape(string(text))) pos = to } ast.Inspect(n, func(n ast.Node) bool { diff --git a/gopls/internal/golang/workspace_symbol.go b/gopls/internal/golang/workspace_symbol.go index 1a0819b4d52..27f96b5fffb 100644 --- a/gopls/internal/golang/workspace_symbol.go +++ b/gopls/internal/golang/workspace_symbol.go @@ -499,12 +499,9 @@ func matchFile(store *symbolStore, symbolizer symbolizer, matcher matcherFunc, f si := &scoredSymbol{ score: score, info: protocol.SymbolInformation{ - Name: strings.Join(symbolParts, ""), - Kind: sym.Kind, - Location: protocol.Location{ - URI: f.uri, - Range: sym.Range, 
- }, + Name: strings.Join(symbolParts, ""), + Kind: sym.Kind, + Location: f.uri.Location(sym.Range), ContainerName: string(f.mp.PkgPath), }, } diff --git a/gopls/internal/protocol/mapper.go b/gopls/internal/protocol/mapper.go index a4aa2e2efe8..17935b4c4a0 100644 --- a/gopls/internal/protocol/mapper.go +++ b/gopls/internal/protocol/mapper.go @@ -163,7 +163,7 @@ func (m *Mapper) OffsetLocation(start, end int) (Location, error) { if err != nil { return Location{}, err } - return m.RangeLocation(rng), nil + return m.URI.Location(rng), nil } // OffsetRange converts a byte-offset interval to a protocol (UTF-16) range. @@ -324,7 +324,7 @@ func (m *Mapper) PosLocation(tf *token.File, start, end token.Pos) (Location, er if err != nil { return Location{}, err } - return m.RangeLocation(rng), nil + return m.URI.Location(rng), nil } // PosRange converts a token range to a protocol (UTF-16) range. @@ -336,14 +336,23 @@ func (m *Mapper) PosRange(tf *token.File, start, end token.Pos) (Range, error) { return m.OffsetRange(startOffset, endOffset) } +// PosText returns the source text for the token range. +func (m *Mapper) PosText(tf *token.File, start, end token.Pos) ([]byte, error) { + startOffset, endOffset, err := safetoken.Offsets(tf, start, end) + if err != nil { + return nil, err + } + return m.Content[startOffset:endOffset], nil +} + // NodeRange converts a syntax node range to a protocol (UTF-16) range. func (m *Mapper) NodeRange(tf *token.File, node ast.Node) (Range, error) { return m.PosRange(tf, node.Pos(), node.End()) } -// RangeLocation pairs a protocol Range with its URI, in a Location. -func (m *Mapper) RangeLocation(rng Range) Location { - return Location{URI: m.URI, Range: rng} +// NodeText returns the source text for syntax node range. +func (m *Mapper) NodeText(tf *token.File, node ast.Node) ([]byte, error) { + return m.PosText(tf, node.Pos(), node.End()) } // LocationTextDocumentPositionParams converts its argument to its result. 
diff --git a/gopls/internal/protocol/uri.go b/gopls/internal/protocol/uri.go index 491d767805f..5d00009b30d 100644 --- a/gopls/internal/protocol/uri.go +++ b/gopls/internal/protocol/uri.go @@ -110,6 +110,11 @@ func (uri DocumentURI) Encloses(file DocumentURI) bool { return pathutil.InDir(uri.Path(), file.Path()) } +// Location returns the Location for the specified range of this URI's file. +func (uri DocumentURI) Location(rng Range) Location { + return Location{URI: uri, Range: rng} +} + func filename(uri DocumentURI) (string, error) { if uri == "" { return "", nil diff --git a/gopls/internal/server/symbols.go b/gopls/internal/server/symbols.go index 40df7369f51..334154add5b 100644 --- a/gopls/internal/server/symbols.go +++ b/gopls/internal/server/symbols.go @@ -52,10 +52,7 @@ func (s *server) DocumentSymbol(ctx context.Context, params *protocol.DocumentSy Name: s.Name, Kind: s.Kind, Deprecated: s.Deprecated, - Location: protocol.Location{ - URI: params.TextDocument.URI, - Range: s.Range, - }, + Location: params.TextDocument.URI.Location(s.Range), } } return symbols, nil diff --git a/gopls/internal/template/implementations.go b/gopls/internal/template/implementations.go index 4ed485cfee2..5ae4bf2a182 100644 --- a/gopls/internal/template/implementations.go +++ b/gopls/internal/template/implementations.go @@ -96,7 +96,7 @@ func Definition(snapshot *cache.Snapshot, fh file.Handle, loc protocol.Position) if !s.vardef || s.name != sym { continue } - ans = append(ans, protocol.Location{URI: k, Range: p.Range(s.start, s.length)}) + ans = append(ans, k.Location(p.Range(s.start, s.length))) } } return ans, nil @@ -149,7 +149,7 @@ func References(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p if s.vardef && !params.Context.IncludeDeclaration { continue } - ans = append(ans, protocol.Location{URI: k, Range: p.Range(s.start, s.length)}) + ans = append(ans, k.Location(p.Range(s.start, s.length))) } } // do these need to be sorted?
(a.files is a map) diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index 4acfb06393d..6ac9dc17e04 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -1724,7 +1724,7 @@ func (e *Editor) Hover(ctx context.Context, loc protocol.Location) (*protocol.Ma if resp == nil { return nil, protocol.Location{}, nil // e.g. no selected symbol } - return &resp.Contents, protocol.Location{URI: loc.URI, Range: resp.Range}, nil + return &resp.Contents, loc.URI.Location(resp.Range), nil } func (e *Editor) DocumentLink(ctx context.Context, path string) ([]protocol.DocumentLink, error) { diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 83df6781662..8bb2fc35490 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -1817,7 +1817,7 @@ func sortDocumentHighlights(s []protocol.DocumentHighlight) { func highlightAllMarker(mark marker, all ...protocol.DocumentHighlight) { sortDocumentHighlights(all) for _, src := range all { - loc := protocol.Location{URI: mark.uri(), Range: src.Range} + loc := mark.uri().Location(src.Range) got := mark.run.env.DocumentHighlight(loc) sortDocumentHighlights(got) @@ -1828,7 +1828,7 @@ func highlightAllMarker(mark marker, all ...protocol.DocumentHighlight) { } func highlightMarker(mark marker, src protocol.DocumentHighlight, dsts ...protocol.DocumentHighlight) { - loc := protocol.Location{URI: mark.uri(), Range: src.Range} + loc := mark.uri().Location(src.Range) got := mark.run.env.DocumentHighlight(loc) sortDocumentHighlights(got) @@ -2198,7 +2198,7 @@ func documentLinkMarker(mark marker, g *Golden) { mark.errorf("%s: nil link target", l.Range) continue } - loc := protocol.Location{URI: mark.uri(), Range: l.Range} + loc := mark.uri().Location(l.Range) fmt.Fprintln(&b, mark.run.fmtLocForGolden(loc), *l.Target) } @@ -2438,13 +2438,6 
@@ func implementationMarker(mark marker, src protocol.Location, want ...protocol.L } } -func itemLocation(item protocol.CallHierarchyItem) protocol.Location { - return protocol.Location{ - URI: item.URI, - Range: item.Range, - } -} - func mcpToolMarker(mark marker, tool string, args string, loc protocol.Location) { var toolArgs map[string]any if err := json.Unmarshal([]byte(args), &toolArgs); err != nil { @@ -2494,7 +2487,7 @@ func incomingCallsMarker(mark marker, src protocol.Location, want ...protocol.Lo } var locs []protocol.Location for _, call := range calls { - locs = append(locs, itemLocation(call.From)) + locs = append(locs, call.From.URI.Location(call.From.Range)) } return locs, nil } @@ -2509,7 +2502,7 @@ func outgoingCallsMarker(mark marker, src protocol.Location, want ...protocol.Lo } var locs []protocol.Location for _, call := range calls { - locs = append(locs, itemLocation(call.To)) + locs = append(locs, call.To.URI.Location(call.To.Range)) } return locs, nil } @@ -2530,7 +2523,8 @@ func callHierarchy(mark marker, src protocol.Location, getCalls callHierarchyFun mark.errorf("PrepareCallHierarchy returned %d items, want exactly 1", nitems) return } - if loc := itemLocation(items[0]); loc != src { + item := items[0] + if loc := item.URI.Location(item.Range); loc != src { mark.errorf("PrepareCallHierarchy found call %v, want %v", loc, src) return } @@ -2641,7 +2635,7 @@ func typeHierarchy(mark marker, src protocol.Location, want []protocol.Location, } got := []protocol.Location{} // non-nil; cmp.Diff cares for _, item := range items { - got = append(got, protocol.Location{URI: item.URI, Range: item.Range}) + got = append(got, item.URI.Location(item.Range)) } if d := cmp.Diff(want, got); d != "" { mark.errorf("type hierarchy: unexpected results (-want +got):\n%s", d) From 39596567a4b57eedb9f51504e4c1c55b7a5e04a8 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 28 May 2025 17:48:35 +0000 Subject: [PATCH 162/196] internal/mcp: add iterator methods 
resources and prompts This CL streamlines the client side pagination logic by adding a generic helper that handles pagination. This also simplifies the existing tools iterator method. Change-Id: If4304deda1ccd22d0c4580d76deaefc64e5dfeb9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676975 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Auto-Submit: Sam Thanawalla Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 65 +++++++-- internal/mcp/client_test.go | 201 ++++++++++++++++++++++++++++ internal/mcp/generate.go | 7 + internal/mcp/protocol.go | 18 ++- internal/mcp/server_example_test.go | 116 +++++++++------- internal/mcp/server_test.go | 4 +- 6 files changed, 347 insertions(+), 64 deletions(-) create mode 100644 internal/mcp/client_test.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 3b577a19896..e64bcb238d2 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -319,27 +319,72 @@ func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, para // Tools provides an iterator for all tools available on the server, // automatically fetching pages and managing cursors. // The `params` argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[Tool, error] { - currentParams := &ListToolsParams{} - if params != nil { - *currentParams = *params + if params == nil { + params = &ListToolsParams{} } - return func(yield func(Tool, error) bool) { + return paginate(ctx, params, cs.ListTools, func(res *ListToolsResult) []*Tool { + return res.Tools + }) +} + +// Resources provides an iterator for all resources available on the server, +// automatically fetching pages and managing cursors. +// The `params` argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. 
+func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[Resource, error] { + if params == nil { + params = &ListResourcesParams{} + } + return paginate(ctx, params, cs.ListResources, func(res *ListResourcesResult) []*Resource { + return res.Resources + }) +} + +// Prompts provides an iterator for all prompts available on the server, +// automatically fetching pages and managing cursors. +// The `params` argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. +func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[Prompt, error] { + if params == nil { + params = &ListPromptsParams{} + } + return paginate(ctx, params, cs.ListPrompts, func(res *ListPromptsResult) []*Prompt { + return res.Prompts + }) +} + +type ListParams interface { + // Returns a pointer to the param's Cursor field. + cursorPtr() *string +} + +type ListResult[T any] interface { + // Returns a pointer to the result's NextCursor field. + nextCursorPtr() *string +} + +// paginate is a generic helper function to provide a paginated iterator.
+func paginate[P ListParams, R ListResult[E], E any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*E) iter.Seq2[E, error] { + return func(yield func(E, error) bool) { for { - res, err := cs.ListTools(ctx, currentParams) + res, err := listFunc(ctx, params) if err != nil { - yield(Tool{}, err) + var zero E + yield(zero, err) return } - for _, t := range res.Tools { - if !yield(*t, nil) { + for _, r := range items(res) { + if !yield(*r, nil) { return } } - if res.NextCursor == "" { + nextCursorVal := res.nextCursorPtr() + if nextCursorVal == nil || *nextCursorVal == "" { return } - currentParams.Cursor = res.NextCursor + *params.cursorPtr() = *nextCursorVal } } } diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go new file mode 100644 index 00000000000..3d8b8b23dc6 --- /dev/null +++ b/internal/mcp/client_test.go @@ -0,0 +1,201 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/tools/internal/mcp/jsonschema" +) + +type Item struct { + Name string + Value string +} + +type ListTestParams struct { + Cursor string +} + +func (p *ListTestParams) cursorPtr() *string { + return &p.Cursor +} + +type ListTestResult struct { + Items []*Item + NextCursor string +} + +func (r *ListTestResult) nextCursorPtr() *string { + return &r.NextCursor +} + +var allItems = []*Item{ + {"alpha", "val-A"}, + {"bravo", "val-B"}, + {"charlie", "val-C"}, + {"delta", "val-D"}, + {"echo", "val-E"}, + {"foxtrot", "val-F"}, + {"golf", "val-G"}, + {"hotel", "val-H"}, + {"india", "val-I"}, + {"juliet", "val-J"}, + {"kilo", "val-K"}, +} + +func toItemValueSlice(ptrSlice []*Item) []Item { + var valSlice []Item + for _, ptr := range ptrSlice { + valSlice = append(valSlice, *ptr) + } + return valSlice +} + +// generatePaginatedResults is a helper to create a sequence of mock responses for pagination. +// It simulates a server returning items in pages based on a given page size. 
+func generatePaginatedResults(all []*Item, pageSize int) []*ListTestResult { + if len(all) == 0 { + return []*ListTestResult{{Items: []*Item{}, NextCursor: ""}} + } + if pageSize <= 0 { + panic("pageSize must be greater than 0") + } + numPages := (len(all) + pageSize - 1) / pageSize // Ceiling division + var results []*ListTestResult + for i := range numPages { + startIndex := i * pageSize + endIndex := min(startIndex+pageSize, len(all)) // Use min to prevent out of bounds + nextCursor := "" + if endIndex < len(all) { // If there are more items after this page + nextCursor = fmt.Sprintf("cursor_%d", endIndex) + } + results = append(results, &ListTestResult{Items: all[startIndex:endIndex], NextCursor: nextCursor}) + } + return results +} + +func TestClientPaginateBasic(t *testing.T) { + ctx := context.Background() + testCases := []struct { + name string + results []*ListTestResult + mockError error + initialParams *ListTestParams + expected []Item + expectError bool + }{ + { + name: "SinglePageAllItems", + results: generatePaginatedResults(allItems, len(allItems)), + expected: toItemValueSlice(allItems), + }, + { + name: "MultiplePages", + results: generatePaginatedResults(allItems, 3), + expected: toItemValueSlice(allItems), + }, + { + name: "EmptyResults", + results: generatePaginatedResults([]*Item{}, 10), + expected: nil, + }, + { + name: "ListFuncReturnsErrorImmediately", + results: []*ListTestResult{{}}, + mockError: fmt.Errorf("API error on first call"), + expected: nil, + expectError: true, + }, + { + name: "InitialCursorProvided", + initialParams: &ListTestParams{Cursor: "cursor_2"}, + results: generatePaginatedResults(allItems[2:], 3), + expected: toItemValueSlice(allItems[2:]), + }, + { + name: "CursorBeyondAllItems", + initialParams: &ListTestParams{Cursor: "cursor_999"}, + results: []*ListTestResult{{Items: []*Item{}, NextCursor: ""}}, + expected: nil, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + 
listFunc := func(ctx context.Context, params *ListTestParams) (*ListTestResult, error) { + if len(tc.results) == 0 { + t.Fatalf("listFunc called but no more results defined for test case %q", tc.name) + } + res := tc.results[0] + tc.results = tc.results[1:] + var err error + if tc.mockError != nil { + err = tc.mockError + } + return res, err + } + + params := tc.initialParams + if tc.initialParams == nil { + params = &ListTestParams{} + } + + var gotItems []Item + var iterationErr error + seq := paginate(ctx, params, listFunc, func(r *ListTestResult) []*Item { return r.Items }) + for item, err := range seq { + if err != nil { + iterationErr = err + break + } + gotItems = append(gotItems, item) + } + if tc.expectError { + if iterationErr == nil { + t.Errorf("paginate() expected an error during iteration, but got none") + } + } else { + if iterationErr != nil { + t.Errorf("paginate() got: %v, want: nil", iterationErr) + } + } + if diff := cmp.Diff(tc.expected, gotItems, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("paginate() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestClientPaginateVariousPageSizes(t *testing.T) { + ctx := context.Background() + for i := 1; i < len(allItems)+1; i++ { + testname := fmt.Sprintf("PageSize=%d", i) + t.Run(testname, func(t *testing.T) { + results := generatePaginatedResults(allItems, i) + listFunc := func(ctx context.Context, params *ListTestParams) (*ListTestResult, error) { + res := results[0] + results = results[1:] + return res, nil + } + var gotItems []Item + seq := paginate(ctx, &ListTestParams{}, listFunc, func(r *ListTestResult) []*Item { return r.Items }) + for item, err := range seq { + if err != nil { + t.Fatalf("paginate() unexpected error during iteration: %v", err) + } + gotItems = append(gotItems, item) + } + if diff := cmp.Diff(toItemValueSlice(allItems), gotItems, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("paginate() mismatch (-want +got):\n%s", 
diff) + } + }) + } +} diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 670a79b5969..80ce74dddd6 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -332,6 +332,13 @@ func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, nam fmt.Fprintf(w, "\nfunc (x *%s) GetMeta() *Meta { return &x.Meta }", typeName) } + if _, ok := def.Properties["cursor"]; ok { + fmt.Fprintf(w, "\nfunc (x *%s) cursorPtr() *string { return &x.Cursor }", typeName) + } + if _, ok := def.Properties["nextCursor"]; ok { + fmt.Fprintf(w, "\nfunc (x *%s) nextCursorPtr() *string { return &x.NextCursor }", typeName) + } + return nil } diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index 4e212c6afea..fd4b9af2ae2 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -214,7 +214,8 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListPromptsParams) GetMeta() *Meta { return &x.Meta } +func (x *ListPromptsParams) GetMeta() *Meta { return &x.Meta } +func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } // The server's response to a prompts/list request from the client. type ListPromptsResult struct { @@ -227,7 +228,8 @@ type ListPromptsResult struct { Prompts []*Prompt `json:"prompts"` } -func (x *ListPromptsResult) GetMeta() *Meta { return &x.Meta } +func (x *ListPromptsResult) GetMeta() *Meta { return &x.Meta } +func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourcesParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -238,7 +240,8 @@ type ListResourcesParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListResourcesParams) GetMeta() *Meta { return &x.Meta } +func (x *ListResourcesParams) GetMeta() *Meta { return &x.Meta } +func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } // The server's response to a resources/list request from the client. 
type ListResourcesResult struct { @@ -251,7 +254,8 @@ type ListResourcesResult struct { Resources []*Resource `json:"resources"` } -func (x *ListResourcesResult) GetMeta() *Meta { return &x.Meta } +func (x *ListResourcesResult) GetMeta() *Meta { return &x.Meta } +func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } type ListRootsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -282,7 +286,8 @@ type ListToolsParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListToolsParams) GetMeta() *Meta { return &x.Meta } +func (x *ListToolsParams) GetMeta() *Meta { return &x.Meta } +func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } // The server's response to a tools/list request from the client. type ListToolsResult struct { @@ -295,7 +300,8 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } -func (x *ListToolsResult) GetMeta() *Meta { return &x.Meta } +func (x *ListToolsResult) GetMeta() *Meta { return &x.Meta } +func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } // The severity of a log message. // diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 231a329d004..800fd5d906b 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -57,8 +57,8 @@ func ExampleServer() { } // createSessions creates and connects an in-memory client and server session for testing purposes. 
-func createSessions(ctx context.Context, opts *mcp.ServerOptions) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { - server := mcp.NewServer("server", "v0.0.1", opts) +func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { + server := mcp.NewServer("server", "v0.0.1", nil) client := mcp.NewClient("client", "v0.0.1", nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() serverSession, err := server.Connect(ctx, serverTransport) @@ -77,42 +77,33 @@ func TestListTools(t *testing.T) { toolB := mcp.NewTool("banana", "banana tool", SayHi) toolC := mcp.NewTool("cherry", "cherry tool", SayHi) tools := []*mcp.ServerTool{toolA, toolB, toolC} - wantListTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} - wantIteratorTools := []mcp.Tool{*toolA.Tool, *toolB.Tool, *toolC.Tool} ctx := context.Background() + clientSession, serverSession, server := createSessions(ctx) + defer clientSession.Close() + defer serverSession.Close() + server.AddTools(tools...) t.Run("ListTools", func(t *testing.T) { - clientSession, serverSession, server := createSessions(ctx, nil) - defer clientSession.Close() - defer serverSession.Close() - server.AddTools(tools...) 
+ wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} res, err := clientSession.ListTools(ctx, nil) if err != nil { t.Fatal("ListTools() failed:", err) } - if diff := cmp.Diff(wantListTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff) } }) t.Run("ToolsIterator", func(t *testing.T) { - for pageSize := range len(tools) + 1 { - testName := fmt.Sprintf("PageSize=%v", pageSize) - t.Run(testName, func(t *testing.T) { - clientSession, serverSession, server := createSessions(ctx, &mcp.ServerOptions{PageSize: pageSize}) - defer clientSession.Close() - defer serverSession.Close() - server.AddTools(tools...) - var gotTools []mcp.Tool - seq := clientSession.Tools(ctx, nil) - for tool, err := range seq { - if err != nil { - t.Fatalf("Tools(%s) failed: %v", testName, err) - } - gotTools = append(gotTools, tool) - } - if diff := cmp.Diff(wantIteratorTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("Tools(%s) mismatch (-want +got):\n%s", testName, diff) - } - }) + wantTools := []mcp.Tool{*toolA.Tool, *toolB.Tool, *toolC.Tool} + var gotTools []mcp.Tool + seq := clientSession.Tools(ctx, nil) + for tool, err := range seq { + if err != nil { + t.Fatalf("Tools() failed: %v", err) + } + gotTools = append(gotTools, tool) + } + if diff := cmp.Diff(wantTools, gotTools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("Tools() mismatch (-want +got):\n%s", diff) } }) } @@ -122,19 +113,36 @@ func TestListResources(t *testing.T) { resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}} resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}} resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} - wantResource := []*mcp.Resource{resourceA.Resource, resourceB.Resource, 
resourceC.Resource} ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx, nil) + clientSession, serverSession, server := createSessions(ctx) defer clientSession.Close() defer serverSession.Close() server.AddResources(resources...) - res, err := clientSession.ListResources(ctx, nil) - if err != nil { - t.Fatal("ListResources() failed:", err) - } - if diff := cmp.Diff(wantResource, res.Resources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff) - } + t.Run("ListResources", func(t *testing.T) { + wantResources := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} + res, err := clientSession.ListResources(ctx, nil) + if err != nil { + t.Fatal("ListResources() failed:", err) + } + if diff := cmp.Diff(wantResources, res.Resources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff) + } + }) + t.Run("ResourcesIterator", func(t *testing.T) { + wantResources := []mcp.Resource{*resourceA.Resource, *resourceB.Resource, *resourceC.Resource} + var gotResources []mcp.Resource + seq := clientSession.Resources(ctx, nil) + for resource, err := range seq { + if err != nil { + t.Fatalf("Resources() failed: %v", err) + } + gotResources = append(gotResources, resource) + } + if diff := cmp.Diff(wantResources, gotResources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("Resources() mismatch (-want +got):\n%s", diff) + } + }) + } func TestListPrompts(t *testing.T) { @@ -142,17 +150,33 @@ func TestListPrompts(t *testing.T) { promptB := mcp.NewPrompt("banana", "banana prompt", testPromptHandler[struct{}]) promptC := mcp.NewPrompt("cherry", "cherry prompt", testPromptHandler[struct{}]) prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} - wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} ctx := context.Background() - 
clientSession, serverSession, server := createSessions(ctx, nil) + clientSession, serverSession, server := createSessions(ctx) defer clientSession.Close() defer serverSession.Close() server.AddPrompts(prompts...) - res, err := clientSession.ListPrompts(ctx, nil) - if err != nil { - t.Fatal("ListPrompts() failed:", err) - } - if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ListPrompts() mismatch (-want +got):\n%s", diff) - } + t.Run("ListPrompts", func(t *testing.T) { + wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} + res, err := clientSession.ListPrompts(ctx, nil) + if err != nil { + t.Fatal("ListPrompts() failed:", err) + } + if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ListPrompts() mismatch (-want +got):\n%s", diff) + } + }) + t.Run("PromptsIterator", func(t *testing.T) { + wantPrompts := []mcp.Prompt{*promptA.Prompt, *promptB.Prompt, *promptC.Prompt} + var gotPrompts []mcp.Prompt + seq := clientSession.Prompts(ctx, nil) + for prompt, err := range seq { + if err != nil { + t.Fatalf("Prompts() failed: %v", err) + } + gotPrompts = append(gotPrompts, prompt) + } + if diff := cmp.Diff(wantPrompts, gotPrompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("Prompts() mismatch (-want +got):\n%s", diff) + } + }) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 67a13f9b9ba..16bf8a5317e 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -41,7 +41,7 @@ func getCursor(input string) string { return cursor } -func TestPaginateBasic(t *testing.T) { +func TestServerPaginateBasic(t *testing.T) { testCases := []struct { name string initialItems []*TestItem @@ -170,7 +170,7 @@ func TestPaginateBasic(t *testing.T) { } } -func TestPaginateVariousPageSizes(t *testing.T) { +func TestServerPaginateVariousPageSizes(t *testing.T) { fs 
:= newFeatureSet(func(t *TestItem) string { return t.Name }) fs.add(allTestItems...) // Try all possible page sizes, ensuring we get the correct list of items. From d794c0d920769cd0f3eab4edd8afc075da2527cf Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 27 May 2025 14:15:09 +0000 Subject: [PATCH 163/196] internal/mcp: generic handling of CallToolParams Make Tool APIs future-proof by pushing generics into CallToolParams, and using CallToolParams throughout. This has several advantages: - We no longer need the awkward CallToolOptions type, which is almost but not quite like CallToolParams. - NewTool no longer needs to define its own handler type: it accepts a ToolHandler[TArgs]. This also means its signature is future-proof. Additionally, provide a `CallTool` helper (the analog of NewTool) to serve as a more strongly typed variant of `ClientSession.CallTool`. + design, protocol generation Change-Id: I80172a25bc50734dae84a541aa4776ed9161694e Reviewed-on: https://go-review.googlesource.com/c/tools/+/676495 Reviewed-by: Jonathan Amsterdam Auto-Submit: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/internal/mcp/mcp.go | 12 ++--- gopls/internal/test/marker/marker_test.go | 13 +++--- internal/mcp/README.md | 12 +++-- internal/mcp/client.go | 33 ++++++-------- internal/mcp/cmd_test.go | 5 ++- internal/mcp/design/design.md | 33 ++++++++------ internal/mcp/examples/hello/main.go | 10 +++-- internal/mcp/examples/sse/main.go | 8 ++-- internal/mcp/features_test.go | 8 ++-- internal/mcp/generate.go | 40 +++++++++++++---- internal/mcp/internal/readme/client/client.go | 6 ++- internal/mcp/internal/readme/server/server.go | 6 +-- internal/mcp/mcp_test.go | 44 ++++++++++++------- internal/mcp/protocol.go | 12 +++-- internal/mcp/server.go | 3 +- internal/mcp/server_example_test.go | 15 ++++--- internal/mcp/sse_example_test.go | 11 +++-- internal/mcp/sse_test.go | 5 ++- internal/mcp/tool.go | 21 +++++---- internal/mcp/tool_test.go | 2 +- internal/mcp/transport.go | 4 +- 21 
files changed, 185 insertions(+), 118 deletions(-) diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index b12e463ec37..b3a897d359f 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -119,8 +119,8 @@ func newServer(_ *cache.Cache, session *cache.Session) *mcp.Server { mcp.NewTool( "hello_world", "Say hello to someone", - func(ctx context.Context, _ *mcp.ServerSession, request HelloParams) ([]*mcp.Content, error) { - return helloHandler(ctx, session, request) + func(ctx context.Context, _ *mcp.ServerSession, params *mcp.CallToolParams[HelloParams]) (*mcp.CallToolResult, error) { + return helloHandler(ctx, session, params) }, ), ) @@ -132,9 +132,11 @@ type HelloParams struct { Location Location `json:"loc" mcp:"location inside of a text file"` } -func helloHandler(_ context.Context, _ *cache.Session, request HelloParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent(fmt.Sprintf("Hi %s, current file %s.", request.Name, path.Base(request.Location.URI))), +func helloHandler(_ context.Context, _ *cache.Session, params *mcp.CallToolParams[HelloParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("Hi %s, current file %s.", params.Arguments.Name, path.Base(params.Arguments.Location.URI))), + }, }, nil } diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 8bb2fc35490..3914fb76e4b 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -2438,9 +2438,9 @@ func implementationMarker(mark marker, src protocol.Location, want ...protocol.L } } -func mcpToolMarker(mark marker, tool string, args string, loc protocol.Location) { - var toolArgs map[string]any - if err := json.Unmarshal([]byte(args), &toolArgs); err != nil { +func mcpToolMarker(mark marker, tool string, rawArgs string, loc protocol.Location) { + args := make(map[string]any) + if 
err := json.Unmarshal([]byte(rawArgs), &args); err != nil { mark.errorf("fail to unmarshal arguments to map[string]any: %v", err) return } @@ -2448,9 +2448,12 @@ func mcpToolMarker(mark marker, tool string, args string, loc protocol.Location) // Inserts the location value into the MCP tool arguments map under the // "loc" key. // TODO(hxjiang): Make the "loc" key configurable. - toolArgs["loc"] = loc + args["loc"] = loc - res, err := mark.run.env.MCPSession.CallTool(mark.ctx(), tool, toolArgs, nil) + res, err := mcp.CallTool(mark.ctx(), mark.run.env.MCPSession, &mcp.CallToolParams[map[string]any]{ + Name: tool, + Arguments: args, + }) if err != nil { mark.errorf("failed to call mcp tool: %v", err) return diff --git a/internal/mcp/README.md b/internal/mcp/README.md index a4fc3dee443..7fa3265cee0 100644 --- a/internal/mcp/README.md +++ b/internal/mcp/README.md @@ -50,7 +50,11 @@ func main() { } defer session.Close() // Call a tool on the server. - if res, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { + params := &mcp.CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "you"}, + } + if res, err := mcp.CallTool(ctx, session, params); err != nil { log.Printf("CallTool failed: %v", err) } else { if res.IsError { @@ -78,9 +82,9 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[HiParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{mcp.NewTextContent("Hi " + params.Name)}, }, nil } diff --git a/internal/mcp/client.go b/internal/mcp/client.go index e64bcb238d2..20d1c3bb42d 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -258,35 +258,30 @@ func (cs *ClientSession) ListTools(ctx 
context.Context, params *ListToolsParams) // CallTool calls the tool with the given name and arguments. // Pass a [CallToolOptions] to provide additional request fields. -func (cs *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("calling tool %q: %w", name, err) - } - }() +func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { + return standardCall[CallToolResult](ctx, cs.conn, methodCallTool, params) +} - data, err := json.Marshal(args) +// CallTool is a helper to call a tool with any argument type. It returns an +// error if params.Arguments fails to marshal to JSON. +func CallTool[TArgs any](ctx context.Context, cs *ClientSession, params *CallToolParams[TArgs]) (*CallToolResult, error) { + data, err := json.Marshal(params.Arguments) if err != nil { - return nil, fmt.Errorf("marshaling arguments: %w", err) + return nil, fmt.Errorf("failed to marshal arguments: %v", err) } - params := &CallToolParams{ - Name: name, - Arguments: json.RawMessage(data), + // TODO(rfindley): write a test that guarantees this copying is total. + wireParams := &CallToolParams[json.RawMessage]{ + Meta: params.Meta, + Name: params.Name, + Arguments: data, } - return standardCall[CallToolResult](ctx, cs.conn, methodCallTool, params) + return cs.CallTool(ctx, wireParams) } func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { return call(ctx, cs.conn, methodSetLevel, params, nil) } -// NOTE: the following struct should consist of all fields of callToolParams except name and arguments. - -// CallToolOptions contains options to [ClientSession.CallTool]. -type CallToolOptions struct { - ProgressToken any // string or int -} - // ListResources lists the resources that are currently available on the server. 
func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { return standardCall[ListResourcesResult](ctx, cs.conn, methodListResources, params) diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index 202f8495136..a251b1e2ad8 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -53,7 +53,10 @@ func TestCmdTransport(t *testing.T) { if err != nil { log.Fatal(err) } - got, err := session.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + got, err := mcp.CallTool(ctx, session, &mcp.CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 642fdda3add..95a6ed829e6 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -276,6 +276,7 @@ type Content struct { func NewTextContent(text string) *Content // etc. ``` + The `Meta` type includes a `map[string]any` for arbitrary data, and a `ProgressToken` field. **Differences from mcp-go**: these types are largely similar, but our type generator flattens types rather than using struct embedding. @@ -379,9 +380,9 @@ In our SDK, RPC methods that are defined in the specification take a context and func (*ClientSession) ListTools(context.Context, *ListToolsParams) (*ListToolsResult, error) ``` -Our SDK has a method for every RPC in the spec, and except for `CallTool`, their signatures all share this form. We do this, rather than providing more convenient shortcut signatures, to maintain backward compatibility if the spec makes backward-compatible changes such as adding a new property to the request parameters (as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), for example). 
To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." +Our SDK has a method for every RPC in the spec, and their signatures all share this form. We do this, rather than providing more convenient shortcut signatures, to maintain backward compatibility if the spec makes backward-compatible changes such as adding a new property to the request parameters (as in [this commit](https://github.com/modelcontextprotocol/modelcontextprotocol/commit/2fce8a077688bf8011e80af06348b8fe1dae08ac), for example). To avoid boilerplate, we don't repeat this signature for RPCs defined in the spec; readers may assume it when we mention a "spec method." -`CallTool` is the only exception: for convenience, it takes the tool name and arguments, with an options struct for additional request fields. See the section on Tools below for details. +`CallTool` is the only exception: for convenience when binding to Go argument types, `*CallToolParams[TArgs]` is generic, with a type parameter providing the Go type of the tool arguments. The spec method accepts a `*CallToolParams[json.RawMessage]`, but we provide a generic helper function. See the section on Tools below for details. Why do we use params instead of the full JSON-RPC request? As much as possible, we endeavor to hide JSON-RPC details when they are not relevant to the business logic of your client or server. In this case, the additional information in the JSON-RPC request is just the request ID and method name; the request ID is irrelevant, and the method name is implied by the name of the Go method providing the API. @@ -495,7 +496,6 @@ type Meta struct { } ``` - Handlers can notify their peer about progress by calling the `NotifyProgress` method. The notification is only sent if the peer requested it by providing a progress token.
```go @@ -579,7 +579,15 @@ type ClientOptions struct { A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` is a tool bound to a tool handler. +A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to make `CallToolParams` generic, with a type parameter `TArgs` for the tool argument type. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. The bound `ServerTool` type expects a `json.RawMessage` for its tool arguments, but the `NewTool` constructor described below provides a mechanism to bind a typed handler. + ```go +type CallToolParams[TArgs any] struct { + Meta Meta `json:"_meta,omitempty"` + Arguments TArgs `json:"arguments,omitempty"` + Name string `json:"name"` +} + type Tool struct { Annotations *ToolAnnotations `json:"annotations,omitempty"` Description string `json:"description,omitempty"` @@ -587,11 +595,11 @@ type Tool struct { Name string `json:"name"` } -type ToolHandler func(context.Context, *ServerSession, *CallToolParams) (*CallToolResult, error) +type ToolHandler[TArgs] func(context.Context, *ServerSession, *CallToolParams[TArgs]) (*CallToolResult, error) type ServerTool struct { Tool Tool - Handler ToolHandler + Handler ToolHandler[json.RawMessage] } ``` @@ -620,12 +628,12 @@ We have found that a hybrid model works well, where the _initial_ schema is deri ```go // NewTool creates a Tool using reflection on the given handler. -func NewTool[TInput any](name, description string, handler func(context.Context, *ServerSession, TInput) ([]Content, error), opts …ToolOption) *ServerTool +func NewTool[TArgs any](name, description string, handler ToolHandler[TArgs], opts …ToolOption) *ServerTool type ToolOption interface { /* ... */ } ``` -`NewTool` determines the input schema for a Tool from the struct used in the handler. 
Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: +`NewTool` determines the input schema for a Tool from the `TArgs` type. Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: ```go struct { @@ -666,16 +674,12 @@ Schemas are validated on the server before the tool handler is called. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. -Client sessions can call the spec method `ListTools` or an iterator method `Tools` to list the available tools. - -As mentioned above, the client session method `CallTool` has a non-standard signature, so that `CallTool` can handle the marshalling of tool arguments: the type of `CallToolParams.Arguments` is `json.RawMessage`, to delegate unmarshalling to the tool handler. +Client sessions can call the spec method `ListTools` or an iterator method `Tools` to list the available tools, and use spec method `CallTool` to call tools. Similar to `ServerTool.Handler`, `CallTool` expects `*CallToolParams[json.RawMessage]`, but we provide a generic `CallTool` helper to operate on typed arguments. 
```go -func (c *ClientSession) CallTool(ctx context.Context, name string, args map[string]any, opts *CallToolOptions) (_ *CallToolResult, err error) +func (cs *ClientSession) CallTool(context.Context, *CallToolParams[json.RawMessage]) (*CallToolResult, error) -type CallToolOptions struct { - ProgressToken any // string or int -} +func CallTool[TArgs any](context.Context, *ClientSession, *CallToolParams[TArgs]) (*CallToolResult, error) ``` **Differences from mcp-go**: using variadic options to configure tools was significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. For example, that limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema. @@ -833,6 +837,7 @@ Clients call the spec method `Complete` to request completions. Servers automati MCP specifies a notification for servers to log to clients. Server sessions implement this with the `LoggingMessage` method. It honors the minimum log level established by the client session's `SetLevel` call. As a convenience, we also provide a `slog.Handler` that allows server authors to write logs with the `log/slog` package:: + ```go // A LoggingHandler is a [slog.Handler] for MCP. 
type LoggingHandler struct {...} diff --git a/internal/mcp/examples/hello/main.go b/internal/mcp/examples/hello/main.go index 1798c1810a7..84e409d98d8 100644 --- a/internal/mcp/examples/hello/main.go +++ b/internal/mcp/examples/hello/main.go @@ -21,13 +21,15 @@ type HiArgs struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiArgs) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParams[HiArgs]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{ + mcp.NewTextContent("Hi " + params.Name), + }, }, nil } -func PromptHi(ctx context.Context, cc *mcp.ServerSession, args *HiArgs, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func PromptHi(ctx context.Context, ss *mcp.ServerSession, args *HiArgs, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []*mcp.PromptMessage{ diff --git a/internal/mcp/examples/sse/main.go b/internal/mcp/examples/sse/main.go index 201f3a091ce..0447e7f336a 100644 --- a/internal/mcp/examples/sse/main.go +++ b/internal/mcp/examples/sse/main.go @@ -19,9 +19,11 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[SayHiParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{ + mcp.NewTextContent("Hi " + params.Name), + }, }, nil } diff --git a/internal/mcp/features_test.go b/internal/mcp/features_test.go index 2bda8745932..327f817c2e7 100644 --- a/internal/mcp/features_test.go +++ b/internal/mcp/features_test.go @@ -18,9 +18,11 @@ type SayHiParams struct { Name 
string `json:"name"` } -func SayHi(ctx context.Context, cc *ServerSession, params *SayHiParams) ([]*Content, error) { - return []*Content{ - NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParams[SayHiParams]) (*CallToolResult, error) { + return &CallToolResult{ + Content: []*Content{ + NewTextContent("Hi " + params.Name), + }, }, nil } diff --git a/internal/mcp/generate.go b/internal/mcp/generate.go index 80ce74dddd6..d4b63c15eb7 100644 --- a/internal/mcp/generate.go +++ b/internal/mcp/generate.go @@ -38,9 +38,10 @@ var schemaFile = flag.String("schema_file", "", "if set, use this file as the pe // struct field. In others, we may want to extract the type definition to a // name. type typeConfig struct { - Name string // declaration name for the type - Substitute string // type definition to substitute - Fields config // individual field configuration, or nil + Name string // declaration name for the type + TypeParams [][2]string // formatted type parameter list ({name, constraint}), if any + Substitute string // type definition to substitute + Fields config // individual field configuration, or nil } type config map[string]*typeConfig @@ -57,9 +58,10 @@ var declarations = config{ Name: "-", Fields: config{ "Params": { - Name: "CallToolParams", + Name: "CallToolParams", + TypeParams: [][2]string{{"TArgs", "any"}}, Fields: config{ - "Arguments": {Substitute: "json.RawMessage"}, + "Arguments": {Substitute: "TArgs"}, }, }, }, @@ -224,8 +226,6 @@ func main() { package mcp import ( - "encoding/json" - "golang.org/x/tools/internal/mcp/jsonschema" ) `) @@ -320,7 +320,18 @@ func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, nam if def.Description != "" { fmt.Fprintf(buf, "%s\n", toComment(def.Description)) } - fmt.Fprintf(buf, "type %s ", typeName) + typeParams := new(strings.Builder) + if len(config.TypeParams) > 0 { + typeParams.WriteByte('[') + for i, p := range config.TypeParams { + if i 
> 0 { + typeParams.WriteString(", ") + } + fmt.Fprintf(typeParams, "%s %s", p[0], p[1]) + } + typeParams.WriteByte(']') + } + fmt.Fprintf(buf, "type %s%s ", typeName, typeParams) } if err := writeType(w, &config, def, named); err != nil { return err // Better error here? @@ -329,7 +340,18 @@ func writeDecl(configName string, config typeConfig, def *jsonschema.Schema, nam // Any decl with a _meta field gets a GetMeta method. if _, ok := def.Properties["_meta"]; ok { - fmt.Fprintf(w, "\nfunc (x *%s) GetMeta() *Meta { return &x.Meta }", typeName) + targs := new(strings.Builder) + if len(config.TypeParams) > 0 { + targs.WriteByte('[') + for i, p := range config.TypeParams { + if i > 0 { + targs.WriteString(", ") + } + fmt.Fprintf(targs, "%s", p[0]) + } + targs.WriteByte(']') + } + fmt.Fprintf(w, "\nfunc (x *%s%s) GetMeta() *Meta { return &x.Meta }", typeName, targs) } if _, ok := def.Properties["cursor"]; ok { diff --git a/internal/mcp/internal/readme/client/client.go b/internal/mcp/internal/readme/client/client.go index 97600e7d2ab..53b36dd2547 100644 --- a/internal/mcp/internal/readme/client/client.go +++ b/internal/mcp/internal/readme/client/client.go @@ -25,7 +25,11 @@ func main() { } defer session.Close() // Call a tool on the server. 
- if res, err := session.CallTool(ctx, "greet", map[string]any{"name": "you"}, nil); err != nil { + params := &mcp.CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "you"}, + } + if res, err := mcp.CallTool(ctx, session, params); err != nil { log.Printf("CallTool failed: %v", err) } else { if res.IsError { diff --git a/internal/mcp/internal/readme/server/server.go b/internal/mcp/internal/readme/server/server.go index 867d4c1e08d..504c7456619 100644 --- a/internal/mcp/internal/readme/server/server.go +++ b/internal/mcp/internal/readme/server/server.go @@ -15,9 +15,9 @@ type HiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *HiParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[HiParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{mcp.NewTextContent("Hi " + params.Name)}, }, nil } diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index e908811941a..882a4e5ce51 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -30,16 +30,18 @@ type hiParams struct { Name string } -func sayHi(ctx context.Context, ss *ServerSession, v hiParams) ([]*Content, error) { +func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParams[hiParams]) (*CallToolResult, error) { if err := ss.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } - return []*Content{NewTextContent("hi " + v.Name)}, nil + return &CallToolResult{Content: []*Content{NewTextContent("hi " + params.Arguments.Name)}}, nil } func TestEndToEnd(t *testing.T) { ctx := context.Background() - ct, st := NewInMemoryTransports() + var ct, st Transport = NewInMemoryTransports() + // ct = NewLoggingTransport(ct, os.Stderr) + // st = NewLoggingTransport(st, os.Stderr) // Channels to check if 
notification callbacks happened. notificationChans := map[string]chan int{} @@ -185,7 +187,10 @@ func TestEndToEnd(t *testing.T) { t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) } - gotHi, err := cs.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + gotHi, err := CallTool(ctx, cs, &CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }) if err != nil { t.Fatal(err) } @@ -196,7 +201,10 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } - gotFail, err := cs.CallTool(ctx, "fail", map[string]any{}, nil) + gotFail, err := CallTool(ctx, cs, &CallToolParams[map[string]any]{ + Name: "fail", + Arguments: map[string]any{}, + }) // Counter-intuitively, when a tool fails, we don't expect an RPC error for // call tool: instead, the failure is embedded in the result. if err != nil { @@ -387,7 +395,7 @@ var ( tools = map[string]*ServerTool{ "greet": NewTool("greet", "say hi", sayHi), - "fail": NewTool("fail", "just fail", func(context.Context, *ServerSession, struct{}) ([]*Content, error) { + "fail": NewTool("fail", "just fail", func(context.Context, *ServerSession, *CallToolParams[struct{}]) (*CallToolResult, error) { return nil, errTestFailure }), } @@ -506,24 +514,30 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Clien } func TestServerClosing(t *testing.T) { - cc, c := basicConnection(t, NewTool("greet", "say hi", sayHi)) - defer c.Close() + cc, cs := basicConnection(t, NewTool("greet", "say hi", sayHi)) + defer cs.Close() ctx := context.Background() var wg sync.WaitGroup wg.Add(1) go func() { - if err := c.Wait(); err != nil { + if err := cs.Wait(); err != nil { t.Errorf("server connection failed: %v", err) } wg.Done() }() - if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil); err != nil { + if _, err := CallTool(ctx, cs, &CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": 
"user"}, + }); err != nil { t.Fatalf("after connecting: %v", err) } cc.Close() wg.Wait() - if _, err := c.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil); !errors.Is(err, ErrConnectionClosed) { + if _, err := CallTool(ctx, cs, &CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }); !errors.Is(err, ErrConnectionClosed) { t.Errorf("after disconnection, got error %v, want EOF", err) } } @@ -572,7 +586,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, cc *ServerSession, v struct{}) ([]*Content, error) { + slowRequest := func(ctx context.Context, cc *ServerSession, params *CallToolParams[struct{}]) (*CallToolResult, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -582,11 +596,11 @@ func TestCancellation(t *testing.T) { } return nil, nil } - _, sc := basicConnection(t, NewTool("slow", "a slow request", slowRequest)) - defer sc.Close() + _, cs := basicConnection(t, NewTool("slow", "a slow request", slowRequest)) + defer cs.Close() ctx, cancel := context.WithCancel(context.Background()) - go sc.CallTool(ctx, "slow", map[string]any{}, nil) + go CallTool(ctx, cs, &CallToolParams[struct{}]{Name: "slow"}) <-start cancel() select { diff --git a/internal/mcp/protocol.go b/internal/mcp/protocol.go index fd4b9af2ae2..07db7d32873 100644 --- a/internal/mcp/protocol.go +++ b/internal/mcp/protocol.go @@ -7,8 +7,6 @@ package mcp import ( - "encoding/json" - "golang.org/x/tools/internal/mcp/jsonschema" ) @@ -28,15 +26,15 @@ type Annotations struct { Priority float64 `json:"priority,omitempty"` } -type CallToolParams struct { +type CallToolParams[TArgs any] struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. 
- Meta Meta `json:"_meta,omitempty"` - Arguments json.RawMessage `json:"arguments,omitempty"` - Name string `json:"name"` + Meta Meta `json:"_meta,omitempty"` + Arguments TArgs `json:"arguments,omitempty"` + Name string `json:"name"` } -func (x *CallToolParams) GetMeta() *Meta { return &x.Meta } +func (x *CallToolParams[TArgs]) GetMeta() *Meta { return &x.Meta } // The server's response to a tool call. // diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 56604d4d071..79321b8046b 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -9,6 +9,7 @@ import ( "context" "encoding/base64" "encoding/gob" + "encoding/json" "fmt" "iter" "net/url" @@ -227,7 +228,7 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool return res, nil } -func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParams) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { s.mu.Lock() tool, ok := s.tools.get(params.Name) s.mu.Unlock() diff --git a/internal/mcp/server_example_test.go b/internal/mcp/server_example_test.go index 800fd5d906b..07d34b9ac9c 100644 --- a/internal/mcp/server_example_test.go +++ b/internal/mcp/server_example_test.go @@ -17,12 +17,14 @@ import ( ) type SayHiParams struct { - Name string `json:"name" mcp:"the name to say hi to"` + Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *SayHiParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent("Hi " + params.Name), +func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[SayHiParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{ + mcp.NewTextContent("Hi " + params.Arguments.Name), + }, }, nil } @@ -44,7 +46,10 @@ func ExampleServer() { log.Fatal(err) } - res, err := clientSession.CallTool(ctx, "greet", 
map[string]any{"name": "user"}, nil) + res, err := mcp.CallTool(ctx, clientSession, &mcp.CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_example_test.go b/internal/mcp/sse_example_test.go index ef4269d46ff..6ad05b9bb68 100644 --- a/internal/mcp/sse_example_test.go +++ b/internal/mcp/sse_example_test.go @@ -18,9 +18,9 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, cc *mcp.ServerSession, params *AddParams) ([]*mcp.Content, error) { - return []*mcp.Content{ - mcp.NewTextContent(fmt.Sprintf("%d", params.X+params.Y)), +func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[AddParams]) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []*mcp.Content{mcp.NewTextContent(fmt.Sprintf("%d", params.Arguments.X+params.Arguments.Y))}, }, nil } @@ -41,7 +41,10 @@ func ExampleSSEHandler() { } defer cs.Close() - res, err := cs.CallTool(ctx, "add", map[string]any{"x": 1, "y": 2}, nil) + res, err := mcp.CallTool(ctx, cs, &mcp.CallToolParams[map[string]any]{ + Name: "add", + Arguments: map[string]any{"x": 1, "y": 2}, + }) if err != nil { log.Fatal(err) } diff --git a/internal/mcp/sse_test.go b/internal/mcp/sse_test.go index cba0ada9235..cef1dbbed9d 100644 --- a/internal/mcp/sse_test.go +++ b/internal/mcp/sse_test.go @@ -44,7 +44,10 @@ func TestSSEServer(t *testing.T) { t.Fatal(err) } ss := <-conns - gotHi, err := cs.CallTool(ctx, "greet", map[string]any{"name": "user"}, nil) + gotHi, err := CallTool(ctx, cs, &CallToolParams[map[string]any]{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }) if err != nil { t.Fatal(err) } diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 43ebe1bfdb4..099321fde1e 100644 --- a/internal/mcp/tool.go +++ b/internal/mcp/tool.go @@ -13,12 +13,12 @@ import ( ) // A ToolHandler handles a call to tools/call. 
-type ToolHandler func(context.Context, *ServerSession, *CallToolParams) (*CallToolResult, error) +type ToolHandler[TArgs any] func(context.Context, *ServerSession, *CallToolParams[TArgs]) (*CallToolResult, error) // A Tool is a tool definition that is bound to a tool handler. type ServerTool struct { Tool *Tool - Handler ToolHandler + Handler ToolHandler[json.RawMessage] } // NewTool is a helper to make a tool using reflection on the given handler. @@ -34,17 +34,19 @@ type ServerTool struct { // // TODO: just have the handler return a CallToolResult: returning []Content is // going to be inconsistent with other server features. -func NewTool[TReq any](name, description string, handler func(context.Context, *ServerSession, TReq) ([]*Content, error), opts ...ToolOption) *ServerTool +func NewTool[TReq any](name, description string, handler ToolHandler[TReq], opts ...ToolOption) *ServerTool { schema, err := jsonschema.For[TReq]() if err != nil { panic(err) } - wrapped := func(ctx context.Context, cc *ServerSession, params *CallToolParams) (*CallToolResult, error) { - var v TReq - if err := unmarshalSchema(params.Arguments, schema, &v); err != nil { - return nil, err + wrapped := func(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { + var params2 CallToolParams[TReq] + if params.Arguments != nil { + if err := unmarshalSchema(params.Arguments, schema, &params2.Arguments); err != nil { + return nil, err + } + } - content, err := handler(ctx, cc, v) + res, err := handler(ctx, cc, &params2) // TODO: investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors.
if err != nil { @@ -53,9 +55,6 @@ func NewTool[TReq any](name, description string, handler func(context.Context, * IsError: true, }, nil } - res := &CallToolResult{ - Content: content, - } return res, nil } t := &ServerTool{ diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index ae4e5ee93e5..646e9b32992 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -15,7 +15,7 @@ import ( ) // testToolHandler is used for type inference in TestNewTool. -func testToolHandler[T any](context.Context, *mcp.ServerSession, T) ([]*mcp.Content, error) { +func testToolHandler[T any](context.Context, *mcp.ServerSession, *mcp.CallToolParams[T]) (*mcp.CallToolResult, error) { panic("not implemented") } diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index 0fbe7082a80..184206ec8d6 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -200,7 +200,7 @@ func (s *loggingStream) Read(ctx context.Context) (jsonrpc2.Message, int64, erro if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } - fmt.Fprintf(s.w, "read: %s", string(data)) + fmt.Fprintf(s.w, "read: %s\n", string(data)) } return msg, n, err } @@ -215,7 +215,7 @@ func (s *loggingStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } - fmt.Fprintf(s.w, "write: %s", string(data)) + fmt.Fprintf(s.w, "write: %s\n", string(data)) } return n, err } From 73f12340029e3ffab6a6e9f53be549a537b61bdd Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 30 May 2025 22:39:39 +0000 Subject: [PATCH 164/196] internal/mcp: community design Add a section to the MCP SDK design discussing governance and community engagement.
Change-Id: I554d8d5623201053a17a17c21dc596c4cdd0878c Reviewed-on: https://go-review.googlesource.com/c/tools/+/677540 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/design/design.md | 62 ++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 95a6ed829e6..b13f4164515 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -20,7 +20,7 @@ These may be obvious, but it's worthwhile to define goals for an official MCP SD - **future-proof**: the SDK should allow for future evolution of the MCP spec, in such a way that we can (as much as possible) avoid incompatible changes to the SDK API. - **extensible**: to best serve the previous four concerns, the SDK should be minimal. However, it should admit extensibility using (for example) simple interfaces, middleware, or hooks. -# Design considerations +# Design In the sections below, we visit each aspect of the MCP spec, in approximately the order they are presented by the [official spec](https://modelcontextprotocol.io/specification/2025-03-26) For each, we discuss considerations for the Go implementation, and propose a Go API. @@ -904,3 +904,63 @@ Client requests for List methods include an optional Cursor field for pagination In addition to the `List` methods, the SDK provides an iterator method for each list operation. This simplifies pagination for clients by automatically handling the underlying pagination logic. See [Iterator Methods](#iterator-methods) above. **Differences with mcp-go**: the PageSize configuration is set with a configuration field rather than a variadic option. Additionally, this design proposes pagination by default, as this is likely desirable for most servers + +# Governance and Community + +While the sections above propose an initial implementation of the Go SDK, MCP is evolving rapidly. 
SDKs need to keep pace, by implementing changes to the spec, fixing bugs, and accommodating new and emerging use-cases. This section proposes how the SDK project can be managed so that it can change safely and transparently. + +Initially, the Go SDK repository will be administered by the Go team and Anthropic, and they will be the Approvers (the set of people able to merge PRs to the SDK). The policies here are also intended to satisfy necessary constraints of the Go team's participation in the project. + +The content in this section will also be included in a CONTRIBUTING.md file in the repo root. + +## Hosting, copyright, and license + +The SDK will be hosted under github.com/modelcontextprotocol/go-sdk, MIT license, copyright "Go SDK Authors". Each Go file in the repository will have a standard copyright header. For example: + +```go +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. +``` + +## Issues and Contributing + +The SDK will use its GitHub issue tracker for bug tracking, and pull requests for contributions. + +Contributions to the SDK will be welcomed, and will be accepted provided they are high quality and consistent with the direction and philosophy of the SDK outlined above. An official SDK must be conservative in the changes it accepts, to defend against compatibility problems, security vulnerabilities, and churn. To avoid being declined, PRs should be associated with open issues, and those issues should either be labeled 'Help Wanted', or the PR author should ask on the issue before contributing. + +### Proposals + +A proposal is an issue that proposes a new API for the SDK, or a change to the signature or behavior of an existing API. Proposals will be labeled with the 'Proposal' label, and require an explicit approval before being accepted (applied through the 'Proposal-Accepted' label).
Proposals will remain open for at least a week to allow discussion before being accepted or declined by an Approver. + +Proposals that are straightforward and uncontroversial may be approved based on GitHub discussion. However, proposals that are deemed to be sufficiently unclear or complicated will be deferred to a regular steering meeting (see below). + +This process is similar to the [Go proposal process](https://github.com/golang/proposal), but is necessarily lighter weight to accommodate the greater rate of change expected for the SDK. + +### Steering meetings + +On a regular basis, we will host a virtual steering meeting to discuss outstanding proposals and other changes to the SDK. These 1hr meetings and their agenda will be announced in advance, and open to all to join. The meetings will be recorded, and recordings and meeting notes will be made available afterward. + +This process is similar to the [Go Tools call](https://go.dev/wiki/golang-tools), though it is expected that meetings will at least initially occur on a more frequent basis (likely biweekly). + +### Discord + +Discord (either the public or private Anthropic Discord servers) should only be used for logistical coordination or answering questions. Design discussion and decisions should occur in GitHub issues or public steering meetings. + +### Antitrust considerations + +It is important that the SDK avoids bias toward specific integration paths or providers. Therefore, the CONTRIBUTING.md file will include an antitrust policy that outlines terms and practices intended to avoid such bias, or the appearance thereof. (The details of this policy will be determined by Google and Anthropic lawyers.) + +## Releases and Versioning + +The SDK will consist of a single Go module, and will be released through versioned Git tags. Accordingly, it will follow semantic versioning. + +Up until the v1.0.0 release, the SDK may be unstable and may change in breaking ways.
An initial v1.0.0 release will occur when the SDK is deemed by Approvers to be stable, production ready, and sufficiently complete (though some unimplemented features may remain). Subsequent to that release, new APIs will be added in minor versions, and breaking changes will require a v2 release of the module (and therefore should be avoided). All releases will have corresponding release notes in GitHub. + +It is desirable that releases occur frequently, and that a v1.0.0 release is achieved as quickly as possible. + +If feasible, the SDK will support all versions of the MCP spec. However, if breaking changes to the spec make this infeasible, preference will be given to the most recent version of the MCP spec. + +## Ongoing evaluation + +On an ongoing basis, the administrators of the SDK will evaluate whether it is keeping pace with changes to the MCP spec and meeting its goals of openness and transparency. If it is not meeting these goals, either because it exceeds the bandwidth of its current Approvers, or because the processes here are inadequate, these processes will be re-evaluated. At this time, the Approvers set may be expanded to include additional community members, based on their history of strong contribution. From d0c0a57b5fe1cb6faf1d943926740f8f165d82e8 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Jun 2025 11:29:41 -0400 Subject: [PATCH 165/196] go/analysis/passes/unusedresult: remove maps.Copy It has no result, and was added by mistake. 
Change-Id: Ifa93c3ee9c2acf0bbfffac64e752e022bb44b58e Reviewed-on: https://go-review.googlesource.com/c/tools/+/678015 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- go/analysis/passes/unusedresult/unusedresult.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/analysis/passes/unusedresult/unusedresult.go b/go/analysis/passes/unusedresult/unusedresult.go index 193cabc2022..556ffed7d99 100644 --- a/go/analysis/passes/unusedresult/unusedresult.go +++ b/go/analysis/passes/unusedresult/unusedresult.go @@ -75,7 +75,6 @@ func init() { "maps.All": true, "maps.Clone": true, "maps.Collect": true, - "maps.Copy": true, "maps.Equal": true, "maps.EqualFunc": true, "maps.Keys": true, From a405109bb2336cbbde72ccb2149b536e19907759 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 2 Jun 2025 14:45:53 +0000 Subject: [PATCH 166/196] internal/mcp: add a test for complete mapping of CallToolParams fields Add a test to verify that the toWireParams helper maps all fields of CallToolParams. Change-Id: I81c28791734d60c793c76809c6b93d27e78c3c36 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677975 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/mcp/client.go | 17 +++++++++++++---- internal/mcp/client_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 20d1c3bb42d..5ff81e000ad 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -265,17 +265,26 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams[js // CallTool is a helper to call a tool with any argument type. It returns an // error if params.Arguments fails to marshal to JSON. 
func CallTool[TArgs any](ctx context.Context, cs *ClientSession, params *CallToolParams[TArgs]) (*CallToolResult, error) { + wireParams, err := toWireParams(params) + if err != nil { + return nil, err + } + return cs.CallTool(ctx, wireParams) +} + +func toWireParams[TArgs any](params *CallToolParams[TArgs]) (*CallToolParams[json.RawMessage], error) { data, err := json.Marshal(params.Arguments) if err != nil { return nil, fmt.Errorf("failed to marshal arguments: %v", err) } - // TODO(rfindley): write a test that guarantees this copying is total. - wireParams := &CallToolParams[json.RawMessage]{ + // The field mapping here must be kept up to date with the CallToolParams. + // This is partially enforced by TestToWireParams, which verifies that all + // comparable fields are mapped. + return &CallToolParams[json.RawMessage]{ Meta: params.Meta, Name: params.Name, Arguments: data, - } - return cs.CallTool(ctx, wireParams) + }, nil } func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go index 3d8b8b23dc6..e789e8130ae 100644 --- a/internal/mcp/client_test.go +++ b/internal/mcp/client_test.go @@ -7,6 +7,7 @@ package mcp import ( "context" "fmt" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -199,3 +200,28 @@ func TestClientPaginateVariousPageSizes(t *testing.T) { }) } } + +func TestToWireParams(t *testing.T) { + // This test verifies that toWireParams maps all fields. + // The Meta and Arguments fields are not comparable, so can't be checked by + // this simple test. However, this test will fail if new fields are added, + // and not handled by toWireParams. 
+ params := &CallToolParams[struct{}]{ + Name: "tool", + } + wireParams, err := toWireParams(params) + if err != nil { + t.Fatal(err) + } + v := reflect.ValueOf(wireParams).Elem() + for i := range v.Type().NumField() { + f := v.Type().Field(i) + if f.Name == "Meta" || f.Name == "Arguments" { + continue // not comparable + } + fv := v.Field(i) + if fv.Interface() == reflect.Zero(f.Type).Interface() { + t.Fatalf("toWireParams: unmapped field %q", f.Name) + } + } +} From 61b248fd8afd2efd18a406703e90e8bcee09b74c Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 2 Jun 2025 15:42:28 +0000 Subject: [PATCH 167/196] internal/mcp: add missing testenv.NeedsExec Unbreak our WASM tests by correctly marking a test as testenv.NeedsExec. Change-Id: Ib976657e9b09acc6fc563f6fc26d6066fd64acf5 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678035 Reviewed-by: Alan Donovan Auto-Submit: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/cmd_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/mcp/cmd_test.go b/internal/mcp/cmd_test.go index a251b1e2ad8..b13f24f70e6 100644 --- a/internal/mcp/cmd_test.go +++ b/internal/mcp/cmd_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/tools/internal/mcp" + "golang.org/x/tools/internal/testenv" ) const runAsServer = "_MCP_RUN_AS_SERVER" @@ -38,6 +39,8 @@ func runServer() { } func TestCmdTransport(t *testing.T) { + testenv.NeedsExec(t) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() From e88c5a40c3176d34a6cd32245c779cc2c10c8341 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 25 May 2025 07:35:56 -0400 Subject: [PATCH 168/196] internal/mcp: simplify session type params Write a separate constraint for ClientSession | ServerSession, and change the pointerness of the type parameter. 
There is still a dynamic cast, but overall I think this is a bit simpler: - The exported types have the more readable constraint "Session" instead of "ClientSession | ServerSession". - We can put the unexported methods in Session instead of needing a separate interface (previously "session"). - This CL enables a simplified treatment of sending middleware, in the next CL. Change-Id: I41184fe93ea201d1da29dbfba26ac59f93771277 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676175 Reviewed-by: Sam Thanawalla LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/client.go | 14 +++--- internal/mcp/design/design.md | 12 ++--- internal/mcp/mcp_test.go | 8 ++-- internal/mcp/server.go | 16 +++---- internal/mcp/shared.go | 89 ++++++++++++++++++----------------- 5 files changed, 67 insertions(+), 72 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 5ff81e000ad..068a9e3877e 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -24,7 +24,7 @@ type Client struct { mu sync.Mutex roots *featureSet[*Root] sessions []*ClientSession - methodHandler_ MethodHandler[ClientSession] + methodHandler_ MethodHandler[*ClientSession] } // NewClient creates a new Client. @@ -37,7 +37,7 @@ func NewClient(name, version string, opts *ClientOptions) *Client { name: name, version: version, roots: newFeatureSet(func(r *Root) string { return r.URI }), - methodHandler_: defaultMethodHandler[ClientSession], + methodHandler_: defaultMethodHandler[*ClientSession], } if opts != nil { c.opts = *opts @@ -196,14 +196,14 @@ func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *C // // For example, AddMiddleware(m1, m2, m3) augments the client method handler as // m1(m2(m3(handler))). 
-func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) { +func (c *Client) AddMiddleware(middleware ...Middleware[*ClientSession]) { c.mu.Lock() defer c.mu.Unlock() addMiddleware(&c.methodHandler_, middleware) } // clientMethodInfos maps from the RPC method name to serverMethodInfos. -var clientMethodInfos = map[string]methodInfo[ClientSession]{ +var clientMethodInfos = map[string]methodInfo{ methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), @@ -213,9 +213,7 @@ var clientMethodInfos = map[string]methodInfo[ClientSession]{ notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), } -var _ session[ClientSession] = (*ClientSession)(nil) - -func (cs *ClientSession) methodInfos() map[string]methodInfo[ClientSession] { +func (cs *ClientSession) methodInfos() map[string]methodInfo { return clientMethodInfos } @@ -223,7 +221,7 @@ func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc2.Request) (any return handleRequest(ctx, req, cs) } -func (cs *ClientSession) methodHandler() MethodHandler[ClientSession] { +func (cs *ClientSession) methodHandler() methodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.methodHandler_ diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index b13f4164515..0ed19b2cd5d 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -416,11 +416,11 @@ We provide a mechanism to add MCP-level middleware on the both the client and se // For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or // nil with a non-nil error. // For notifications, a MethodHandler must return nil, nil. 
-type MethodHandler[S ClientSession | ServerSession] func( - ctx context.Context, _ *S, method string, params any) (result any, err error) +type MethodHandler[S Session] func( + ctx context.Context, _ *S, method string, params Params) (result Result, err error) // Middleware is a function from MethodHandlers to MethodHandlers. -type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHandler[S] +type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // AddMiddleware wraps the client/server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one @@ -428,14 +428,14 @@ type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHa // // For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). -func (c *Client) AddMiddleware(middleware ...Middleware[ClientSession]) -func (s *Server) AddMiddleware(middleware ...Middleware[ServerSession]) +func (c *Client) AddMiddleware(middleware ...Middleware[*ClientSession]) +func (s *Server) AddMiddleware(middleware ...Middleware[*ServerSession]) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.MethodHandler[ServerSession]) mcp.MethodHandler[ServerSession]{ +func withLogging(h mcp.MethodHandler[*ServerSession]) mcp.MethodHandler[*ServerSession]{ return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 882a4e5ce51..c8394e10c2f 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -634,10 +634,10 @@ func TestMiddleware(t *testing.T) { // "1" is the outer middleware layer, called first; then "2" is called, and finally // the default dispatcher. 
- s.AddMiddleware(traceCalls[ServerSession](&sbuf, "1"), traceCalls[ServerSession](&sbuf, "2")) + s.AddMiddleware(traceCalls[*ServerSession](&sbuf, "1"), traceCalls[*ServerSession](&sbuf, "2")) c := NewClient("testClient", "v1.0.0", nil) - c.AddMiddleware(traceCalls[ClientSession](&cbuf, "1"), traceCalls[ClientSession](&cbuf, "2")) + c.AddMiddleware(traceCalls[*ClientSession](&cbuf, "1"), traceCalls[*ClientSession](&cbuf, "2")) cs, err := c.Connect(ctx, ct) if err != nil { @@ -679,9 +679,9 @@ func TestMiddleware(t *testing.T) { // traceCalls creates a middleware function that prints the method before and after each call // with the given prefix. -func traceCalls[S ClientSession | ServerSession](w io.Writer, prefix string) Middleware[S] { +func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { return func(h MethodHandler[S]) MethodHandler[S] { - return func(ctx context.Context, sess *S, method string, params Params) (Result, error) { + return func(ctx context.Context, sess S, method string, params Params) (Result, error) { fmt.Fprintf(w, "%s >%s\n", prefix, method) defer fmt.Fprintf(w, "%s <%s\n", prefix, method) return h(ctx, sess, method, params) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 79321b8046b..16d4b8c435d 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -37,7 +37,7 @@ type Server struct { tools *featureSet[*ServerTool] resources *featureSet[*ServerResource] sessions []*ServerSession - methodHandler_ MethodHandler[ServerSession] + methodHandler_ MethodHandler[*ServerSession] } // ServerOptions is used to configure behavior of the server. 
@@ -77,7 +77,7 @@ func NewServer(name, version string, opts *ServerOptions) *Server { prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), - methodHandler_: defaultMethodHandler[ServerSession], + methodHandler_: defaultMethodHandler[*ServerSession], } } @@ -440,14 +440,14 @@ func (ss *ServerSession) LoggingMessage(ctx context.Context, params *LoggingMess // // For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). -func (s *Server) AddMiddleware(middleware ...Middleware[ServerSession]) { +func (s *Server) AddMiddleware(middleware ...Middleware[*ServerSession]) { s.mu.Lock() defer s.mu.Unlock() addMiddleware(&s.methodHandler_, middleware) } // serverMethodInfos maps from the RPC method name to serverMethodInfos. -var serverMethodInfos = map[string]methodInfo[ServerSession]{ +var serverMethodInfos = map[string]methodInfo{ methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)), methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)), methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)), @@ -461,15 +461,11 @@ var serverMethodInfos = map[string]methodInfo[ServerSession]{ notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), } -// *ServerSession implements the session interface. -// See toSession for why this interface seems to be necessary. 
-var _ session[ServerSession] = (*ServerSession)(nil) - -func (ss *ServerSession) methodInfos() map[string]methodInfo[ServerSession] { +func (ss *ServerSession) methodInfos() map[string]methodInfo { return serverMethodInfos } -func (ss *ServerSession) methodHandler() MethodHandler[ServerSession] { +func (ss *ServerSession) methodHandler() methodHandler { ss.server.mu.Lock() defer ss.server.mu.Unlock() return ss.server.methodHandler_ diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index 6f6f3d0e3eb..45478cee3f0 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -24,48 +24,45 @@ import ( // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil. -type MethodHandler[S ClientSession | ServerSession] func( - ctx context.Context, _ *S, method string, params Params) (result Result, err error) +type MethodHandler[S Session] func( + ctx context.Context, _ S, method string, params Params) (result Result, err error) + +// A methodHandler is a MethodHandler[Session] for some session. +// We need to give up type safety here, or we will end up with a type cycle somewhere +// else. For example, if Session.methodHandler returned a MethodHandler[Session], +// the compiler would complain. +type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] + +// A Session is either a ClientSession or a ServerSession. +type Session interface { + *ClientSession | *ServerSession + methodHandler() methodHandler + methodInfos() map[string]methodInfo + getConn() *jsonrpc2.Connection +} // Middleware is a function from MethodHandlers to MethodHandlers. -type Middleware[S ClientSession | ServerSession] func(MethodHandler[S]) MethodHandler[S] +type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // addMiddleware wraps the handler in the middleware functions. 
-func addMiddleware[S ClientSession | ServerSession](handlerp *MethodHandler[S], middleware []Middleware[S]) { +func addMiddleware[S Session](handlerp *MethodHandler[S], middleware []Middleware[S]) { for _, m := range slices.Backward(middleware) { *handlerp = m(*handlerp) } } -// session has methods common to both ClientSession and ServerSession. -type session[S ClientSession | ServerSession] interface { - methodHandler() MethodHandler[S] - methodInfos() map[string]methodInfo[S] - getConn() *jsonrpc2.Connection -} - -// toSession[S] converts its argument to a session[S]. -// Note that since S is constrained to ClientSession | ServerSession, and pointers to those -// types both implement session[S] already, this should be a no-op. -// That it is not, is due (I believe) to a deficency in generics, possibly related to core types. -// TODO(jba): revisit in Go 1.26; perhaps the change in spec due to the removal of core types -// will have resulted by then in a more generous implementation. -func toSession[S ClientSession | ServerSession](sess *S) session[S] { - return any(sess).(session[S]) -} - // defaultMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. -func defaultMethodHandler[S ClientSession | ServerSession](ctx context.Context, sess *S, method string, params Params) (Result, error) { - info, ok := toSession(sess).methodInfos()[method] +func defaultMethodHandler[S Session](ctx context.Context, session S, method string, params Params) (Result, error) { + info, ok := session.methodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. 
return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod(ctx, sess, method, params) + return info.handleMethod.(MethodHandler[S])(ctx, session, method, params) } -func handleRequest[S ClientSession | ServerSession](ctx context.Context, req *jsonrpc2.Request, sess *S) (any, error) { - info, ok := toSession(sess).methodInfos()[req.Method] +func handleRequest[S Session](ctx context.Context, req *jsonrpc2.Request, session S) (any, error) { + info, ok := session.methodInfos()[req.Method] if !ok { return nil, jsonrpc2.ErrNotHandled } @@ -75,8 +72,8 @@ func handleRequest[S ClientSession | ServerSession](ctx context.Context, req *js } // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - mh := toSession(sess).methodHandler() - res, err := mh(ctx, sess, req.Method, params) + mh := session.methodHandler().(MethodHandler[S]) + res, err := mh(ctx, session, req.Method, params) if err != nil { return nil, err } @@ -84,11 +81,11 @@ func handleRequest[S ClientSession | ServerSession](ctx context.Context, req *js } // methodInfo is information about invoking a method. -type methodInfo[TSession ClientSession | ServerSession] struct { +type methodInfo struct { // unmarshal params from the wire into an XXXParams struct unmarshalParams func(json.RawMessage) (Params, error) // run the code for the method - handleMethod MethodHandler[TSession] + handleMethod methodHandler } // The following definitions support converting from typed to untyped method handlers. @@ -98,11 +95,11 @@ type methodInfo[TSession ClientSession | ServerSession] struct { // - R: results // A typedMethodHandler is like a MethodHandler, but with type information. -type typedMethodHandler[S any, P Params, R Result] func(context.Context, *S, P) (R, error) +type typedMethodHandler[S Session, P Params, R Result] func(context.Context, S, P) (R, error) // newMethodInfo creates a methodInfo from a typedMethodHandler. 
-func newMethodInfo[S ClientSession | ServerSession, P Params, R Result](d typedMethodHandler[S, P, R]) methodInfo[S] { - return methodInfo[S]{ +func newMethodInfo[S Session, P Params, R Result](d typedMethodHandler[S, P, R]) methodInfo { + return methodInfo{ unmarshalParams: func(m json.RawMessage) (Params, error) { var p P if m != nil { @@ -112,29 +109,33 @@ func newMethodInfo[S ClientSession | ServerSession, P Params, R Result](d typedM } return p, nil }, - handleMethod: func(ctx context.Context, ss *S, _ string, params Params) (Result, error) { - return d(ctx, ss, params.(P)) - }, + handleMethod: MethodHandler[S](func(ctx context.Context, session S, _ string, params Params) (Result, error) { + return d(ctx, session, params.(P)) + }), } } // serverMethod is glue for creating a typedMethodHandler from a method on Server. -func serverMethod[P Params, R Result](f func(*Server, context.Context, *ServerSession, P) (R, error)) typedMethodHandler[ServerSession, P, R] { +func serverMethod[P Params, R Result]( + f func(*Server, context.Context, *ServerSession, P) (R, error), +) typedMethodHandler[*ServerSession, P, R] { return func(ctx context.Context, ss *ServerSession, p P) (R, error) { return f(ss.server, ctx, ss, p) } } // clientMethod is glue for creating a typedMethodHandler from a method on Server. -func clientMethod[P Params, R Result](f func(*Client, context.Context, *ClientSession, P) (R, error)) typedMethodHandler[ClientSession, P, R] { +func clientMethod[P Params, R Result]( + f func(*Client, context.Context, *ClientSession, P) (R, error), +) typedMethodHandler[*ClientSession, P, R] { return func(ctx context.Context, cs *ClientSession, p P) (R, error) { return f(cs.client, ctx, cs, p) } } // sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. 
-func sessionMethod[S ClientSession | ServerSession, P Params, R Result](f func(*S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { - return func(ctx context.Context, sess *S, p P) (R, error) { +func sessionMethod[S Session, P Params, R Result](f func(S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { + return func(ctx context.Context, sess S, p P) (R, error) { return f(sess, ctx, p) } } @@ -152,7 +153,7 @@ const ( CodeUnsupportedMethod = -31001 ) -func callNotificationHandler[S ClientSession | ServerSession, P any](ctx context.Context, h func(context.Context, *S, *P), sess *S, params *P) (Result, error) { +func callNotificationHandler[S Session, P any](ctx context.Context, h func(context.Context, S, *P), sess S, params *P) (Result, error) { if h != nil { h(ctx, sess, params) } @@ -161,15 +162,15 @@ func callNotificationHandler[S ClientSession | ServerSession, P any](ctx context // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. -func notifySessions[S ClientSession | ServerSession](sessions []*S, method string, params Params) { +func notifySessions[S Session](sessions []S, method string, params Params) { if sessions == nil { return } // TODO: make this timeout configurable, or call Notify asynchronously. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, ss := range sessions { - if err := toSession(ss).getConn().Notify(ctx, method, params); err != nil { + for _, s := range sessions { + if err := s.getConn().Notify(ctx, method, params); err != nil { // TODO(jba): surface this error better log.Printf("calling %s: %v", method, err) } From 4e672d580b0bb0ac10470374a9ea4728bcf84415 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Jun 2025 10:39:52 -0400 Subject: [PATCH 169/196] gopls/internal/protocol: PublishDiagnosticsParams.version omitempty This CL reverts the status of the PublishDiagnostics.version field to omitempty. 
This status was removed by CL 73501, causing some clients (e.g. coc.nvim) to treat the zero version (meaning "missing" to gopls' file.Handle) as an actual version in the past. Unfortunately it is not possible to reproduce the regression using gopls' JSON encoding structures, hence the lack of a test. Fixes golang/go#73501 Updates golang/go#71489 Change-Id: Ia116f948e13782610c9de8d2d7ffd99bdb377d3e Reviewed-on: https://go-review.googlesource.com/c/tools/+/678095 Commit-Queue: Alan Donovan LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- gopls/internal/cmd/cmd.go | 4 ++++ gopls/internal/protocol/generate/tables.go | 3 ++- gopls/internal/protocol/tsprotocol.go | 2 +- gopls/internal/server/diagnostics.go | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index a572622e682..f4cfd99a6ba 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -641,6 +641,10 @@ func updateFile(filename string, old, new []byte, edits []diff.Edit, flags *Edit func (c *cmdClient) PublishDiagnostics(ctx context.Context, p *protocol.PublishDiagnosticsParams) error { // Don't worry about diagnostics without versions. + // + // (Note: the representation of PublishDiagnosticsParams + // cannot distinguish a missing Version from v0, but the + // server never sends back an explicit zero.) 
if p.Version == 0 { return nil } diff --git a/gopls/internal/protocol/generate/tables.go b/gopls/internal/protocol/generate/tables.go index eccaf9cd1c3..1079a0bc6b6 100644 --- a/gopls/internal/protocol/generate/tables.go +++ b/gopls/internal/protocol/generate/tables.go @@ -34,7 +34,7 @@ var goplsStar = map[prop]int{ {"CompletionItem", "kind"}: wantOpt, // need temporary variables {"CompletionParams", "context"}: wantOpt, // needs nil checks - {"Diagnostic", "severity"}: wantOpt, // nil checks or more careful thought + {"Diagnostic", "severity"}: wantOpt, // needs nil checks or more careful thought {"DidSaveTextDocumentParams", "text"}: wantOptStar, // capabilities_test.go:112 logic {"DocumentHighlight", "kind"}: wantOpt, // need temporary variables @@ -46,6 +46,7 @@ var goplsStar = map[prop]int{ {"Hover", "range"}: wantOpt, // complex expressions {"InlayHint", "kind"}: wantOpt, // temporary variables + {"PublishDiagnosticsParams", "version"}: wantOpt, // zero => missing (#73501) {"TextDocumentClientCapabilities", "codeAction"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "completion"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "documentSymbol"}: wantOpt, // A.B.C.D diff --git a/gopls/internal/protocol/tsprotocol.go b/gopls/internal/protocol/tsprotocol.go index a759eb2ed89..5fa2c2415db 100644 --- a/gopls/internal/protocol/tsprotocol.go +++ b/gopls/internal/protocol/tsprotocol.go @@ -4148,7 +4148,7 @@ type PublishDiagnosticsParams struct { // Optional the version number of the document the diagnostics are published for. // // @since 3.15.0 - Version int32 `json:"version"` + Version int32 `json:"version,omitempty"` // An array of diagnostic information items. 
Diagnostics []Diagnostic `json:"diagnostics"` } diff --git a/gopls/internal/server/diagnostics.go b/gopls/internal/server/diagnostics.go index 95046d98117..45739940a03 100644 --- a/gopls/internal/server/diagnostics.go +++ b/gopls/internal/server/diagnostics.go @@ -904,7 +904,7 @@ func (s *server) publishFileDiagnosticsLocked(ctx context.Context, views viewSet if err := s.client.PublishDiagnostics(ctx, &protocol.PublishDiagnosticsParams{ Diagnostics: toProtocolDiagnostics(unique), URI: uri, - Version: version, + Version: version, // 0 ("on disk") => omitted from JSON encoding }); err != nil { return err } From bcaee630847874c63554f72bf1bbd678909c38f0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Jun 2025 10:39:52 -0400 Subject: [PATCH 170/196] gopls/internal/protocol: make some optional integer fields indirect This change makes the following fields, all of which are integers whose zero value is not equivalent to a missing field, indirect, and adds appropriate nil checks or varOf(value) calls: SignatureHelp.activeParameter TextDocumentContentChangePartial.rangeLength WorkDoneProgressBegin.percentage WorkDoneProgressReport.percentage Updates golang/go#73501 Updates golang/go#71489 Change-Id: I81057ac235e64b2dd28097b811ecd5ed6e2b5eef Reviewed-on: https://go-review.googlesource.com/c/tools/+/677936 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Commit-Queue: Alan Donovan Auto-Submit: Alan Donovan --- gopls/internal/cache/analysis.go | 3 +-- gopls/internal/progress/progress.go | 5 +++-- gopls/internal/progress/progress_test.go | 2 +- gopls/internal/protocol/generate/tables.go | 4 ++++ gopls/internal/protocol/tsprotocol.go | 8 ++++---- gopls/internal/server/signature_help.go | 5 ++++- gopls/internal/server/text_synchronization.go | 4 ++-- gopls/internal/test/integration/env_test.go | 5 ++++- gopls/internal/test/marker/marker_test.go | 8 ++++++-- 9 files changed, 29 insertions(+), 15 deletions(-) diff --git a/gopls/internal/cache/analysis.go 
b/gopls/internal/cache/analysis.go index f63bcab2374..b654833e08c 100644 --- a/gopls/internal/cache/analysis.go +++ b/gopls/internal/cache/analysis.go @@ -291,8 +291,7 @@ func (s *Snapshot) Analyze(ctx context.Context, pkgs map[PackageID]*metadata.Pac // Trailing space is intentional: some LSP clients strip newlines. msg := fmt.Sprintf(`Indexed %d/%d packages. (Set "analysisProgressReporting" to false to disable notifications.)`, completed, len(nodes)) - pct := 100 * float64(completed) / float64(len(nodes)) - wd.Report(ctx, msg, pct) + wd.Report(ctx, msg, float64(completed)/float64(len(nodes))) } } } diff --git a/gopls/internal/progress/progress.go b/gopls/internal/progress/progress.go index e35c0fe19dc..d7820654b48 100644 --- a/gopls/internal/progress/progress.go +++ b/gopls/internal/progress/progress.go @@ -187,7 +187,7 @@ func (wd *WorkDone) doCancel() { } // Report reports an update on WorkDone report back to the client. -func (wd *WorkDone) Report(ctx context.Context, message string, percentage float64) { +func (wd *WorkDone) Report(ctx context.Context, message string, fraction float64) { ctx = xcontext.Detach(ctx) // progress messages should not be cancelled if wd == nil { return @@ -204,6 +204,7 @@ func (wd *WorkDone) Report(ctx context.Context, message string, percentage float return } message = strings.TrimSuffix(message, "\n") + percentage := uint32(100 * fraction) err := wd.client.Progress(ctx, &protocol.ProgressParams{ Token: wd.token, Value: &protocol.WorkDoneProgressReport{ @@ -213,7 +214,7 @@ func (wd *WorkDone) Report(ctx context.Context, message string, percentage float // yet use this feature, the value is kept constant here. 
Cancellable: wd.cancel != nil, Message: message, - Percentage: uint32(percentage), + Percentage: &percentage, }, }) if err != nil { diff --git a/gopls/internal/progress/progress_test.go b/gopls/internal/progress/progress_test.go index 687f99ba4a1..db0820f3046 100644 --- a/gopls/internal/progress/progress_test.go +++ b/gopls/internal/progress/progress_test.go @@ -123,7 +123,7 @@ func TestProgressTracker_Reporting(t *testing.T) { t.Errorf("got %d work begun, want %d", gotBegun, test.wantBegun) } // Ignore errors: this is just testing the reporting behavior. - work.Report(ctx, "report", 50) + work.Report(ctx, "report", 0.5) client.mu.Lock() gotReported := client.reported client.mu.Unlock() diff --git a/gopls/internal/protocol/generate/tables.go b/gopls/internal/protocol/generate/tables.go index 1079a0bc6b6..bb3f20cce90 100644 --- a/gopls/internal/protocol/generate/tables.go +++ b/gopls/internal/protocol/generate/tables.go @@ -47,14 +47,18 @@ var goplsStar = map[prop]int{ {"InlayHint", "kind"}: wantOpt, // temporary variables {"PublishDiagnosticsParams", "version"}: wantOpt, // zero => missing (#73501) + {"SignatureHelp", "activeParameter"}: wantOptStar, // unset != zero {"TextDocumentClientCapabilities", "codeAction"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "completion"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "documentSymbol"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "publishDiagnostics"}: wantOpt, // A.B.C.D {"TextDocumentClientCapabilities", "semanticTokens"}: wantOpt, // A.B.C.D {"TextDocumentContentChangePartial", "range"}: wantOptStar, // == nil test + {"TextDocumentContentChangePartial", "rangeLength"}: wantOptStar, // unset != zero {"TextDocumentSyncOptions", "change"}: wantOpt, // &constant + {"WorkDoneProgressBegin", "percentage"}: wantOptStar, // unset != zero {"WorkDoneProgressParams", "workDoneToken"}: wantOpt, // test failures + {"WorkDoneProgressReport", "percentage"}: wantOptStar, // unset != zero 
{"WorkspaceClientCapabilities", "didChangeConfiguration"}: wantOpt, // A.B.C.D {"WorkspaceClientCapabilities", "didChangeWatchedFiles"}: wantOpt, // A.B.C.D } diff --git a/gopls/internal/protocol/tsprotocol.go b/gopls/internal/protocol/tsprotocol.go index 5fa2c2415db..e74a61961f9 100644 --- a/gopls/internal/protocol/tsprotocol.go +++ b/gopls/internal/protocol/tsprotocol.go @@ -4924,7 +4924,7 @@ type SignatureHelp struct { // In future version of the protocol this property might become // mandatory (but still nullable) to better express the active parameter if // the active signature does have any. - ActiveParameter uint32 `json:"activeParameter"` + ActiveParameter *uint32 `json:"activeParameter,omitempty"` } // Client Capabilities for a {@link SignatureHelpRequest}. @@ -5261,7 +5261,7 @@ type TextDocumentContentChangePartial struct { // The optional length of the range that got replaced. // // @deprecated use range instead. - RangeLength uint32 `json:"rangeLength"` + RangeLength *uint32 `json:"rangeLength,omitempty"` // The new text for the provided range. Text string `json:"text"` } @@ -5764,7 +5764,7 @@ type WorkDoneProgressBegin struct { // // The value should be steadily rising. Clients are free to ignore values // that are not following this rule. The value range is [0, 100]. - Percentage uint32 `json:"percentage"` + Percentage *uint32 `json:"percentage,omitempty"` } // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workDoneProgressCancelParams @@ -5824,7 +5824,7 @@ type WorkDoneProgressReport struct { // // The value should be steadily rising. Clients are free to ignore values // that are not following this rule. The value range is [0, 100] - Percentage uint32 `json:"percentage"` + Percentage *uint32 `json:"percentage,omitempty"` } // Workspace specific client capabilities. 
diff --git a/gopls/internal/server/signature_help.go b/gopls/internal/server/signature_help.go index eb464c48e27..8850a1624ae 100644 --- a/gopls/internal/server/signature_help.go +++ b/gopls/internal/server/signature_help.go @@ -43,6 +43,9 @@ func (s *server) SignatureHelp(ctx context.Context, params *protocol.SignatureHe } return &protocol.SignatureHelp{ Signatures: []protocol.SignatureInformation{*info}, - ActiveParameter: uint32(activeParameter), + ActiveParameter: varOf(uint32(activeParameter)), }, nil } + +// varOf returns a new variable whose value is x. +func varOf[T any](x T) *T { return &x } diff --git a/gopls/internal/server/text_synchronization.go b/gopls/internal/server/text_synchronization.go index ad8554d9302..982d0e7a292 100644 --- a/gopls/internal/server/text_synchronization.go +++ b/gopls/internal/server/text_synchronization.go @@ -313,7 +313,7 @@ func (s *server) changedText(ctx context.Context, uri protocol.DocumentURI, chan // Check if the client sent the full content of the file. // We accept a full content change even if the server expected incremental changes. 
- if len(changes) == 1 && changes[0].Range == nil && changes[0].RangeLength == 0 { + if len(changes) == 1 && changes[0].Range == nil && changes[0].RangeLength == nil { changeFull.Inc() return []byte(changes[0].Text), nil } @@ -388,7 +388,7 @@ func (s *server) checkEfficacy(uri protocol.DocumentURI, version int32, change p } if edit.Range.Start == change.Range.Start { // the change and the proposed completion start at the same - if change.RangeLength == 0 && len(change.Text) == 1 { + if (change.RangeLength == nil || *change.RangeLength == 0) && len(change.Text) == 1 { // a single character added it does not count as a completion continue } diff --git a/gopls/internal/test/integration/env_test.go b/gopls/internal/test/integration/env_test.go index 1fa68676b5c..3b62bc748c8 100644 --- a/gopls/internal/test/integration/env_test.go +++ b/gopls/internal/test/integration/env_test.go @@ -38,7 +38,7 @@ func TestProgressUpdating(t *testing.T) { {"foo", protocol.WorkDoneProgressBegin{Kind: "begin", Title: "foo work"}}, {"bar", protocol.WorkDoneProgressBegin{Kind: "begin", Title: "bar work"}}, {"foo", protocol.WorkDoneProgressEnd{Kind: "end"}}, - {"bar", protocol.WorkDoneProgressReport{Kind: "report", Percentage: 42}}, + {"bar", protocol.WorkDoneProgressReport{Kind: "report", Percentage: varOf[uint32](42)}}, } for _, update := range updates { params := &protocol.ProgressParams{ @@ -66,3 +66,6 @@ func TestProgressUpdating(t *testing.T) { t.Errorf("work progress for \"bar\": %v, want %v", got, want) } } + +// varOf returns a new variable whose value is x. 
+func varOf[T any](x T) *T { return &x } diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 3914fb76e4b..0e4a1026dfb 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -1988,8 +1988,12 @@ func signatureMarker(mark marker, src protocol.Location, label string, active in if got := gotLabels[0]; got != label { mark.errorf("signatureHelp: got label %q, want %q", got, label) } - if got := int64(got.ActiveParameter); got != active { - mark.errorf("signatureHelp: got active parameter %d, want %d", got, active) + gotActiveParameter := int64(-1) // => missing + if got.ActiveParameter != nil { + gotActiveParameter = int64(*got.ActiveParameter) + } + if gotActiveParameter != active { + mark.errorf("signatureHelp: got active parameter %d, want %d", gotActiveParameter, active) } } From 8675e27c3bcbd0a0859126b3ce2bcb43e6e21406 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Mon, 19 May 2025 18:14:57 -0400 Subject: [PATCH 171/196] gopls/internal/mcp: add context mcp tool The context tool accepts a location as its input parameter; the MCP server returns context built from that location. The MCP server only walks the imported packages, constructing a summarized Go file for each. For now, the summarized Go file contains: - Comments at the top of the Go file (file doc or package doc). - Package decl. - Exported functions/methods without function bodies. Replace the mcp structure tag with mcp.Description.
For golang/go#73580 Change-Id: I1c3c4e43e26fddb4e820113fac923647d1da95f8 Reviewed-on: https://go-review.googlesource.com/c/tools/+/675738 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan Auto-Submit: Hongxiang Jiang --- gopls/internal/cache/session.go | 15 ++ gopls/internal/golang/symbols.go | 1 - gopls/internal/mcp/context.go | 208 ++++++++++++++++++ gopls/internal/mcp/mcp.go | 66 +++--- gopls/internal/server/call_hierarchy.go | 6 +- gopls/internal/server/code_action.go | 2 +- gopls/internal/server/code_lens.go | 2 +- gopls/internal/server/command.go | 10 +- gopls/internal/server/completion.go | 2 +- gopls/internal/server/definition.go | 4 +- gopls/internal/server/diagnostics.go | 2 +- gopls/internal/server/folding_range.go | 2 +- gopls/internal/server/format.go | 2 +- gopls/internal/server/general.go | 16 -- gopls/internal/server/highlight.go | 2 +- gopls/internal/server/hover.go | 2 +- gopls/internal/server/implementation.go | 2 +- gopls/internal/server/inlay_hint.go | 2 +- gopls/internal/server/link.go | 2 +- gopls/internal/server/references.go | 2 +- gopls/internal/server/rename.go | 4 +- gopls/internal/server/selection_range.go | 2 +- gopls/internal/server/semantic.go | 2 +- gopls/internal/server/signature_help.go | 2 +- gopls/internal/server/symbols.go | 2 +- gopls/internal/server/type_hierarchy.go | 6 +- gopls/internal/server/workspace.go | 2 +- gopls/internal/test/marker/marker_test.go | 6 +- .../test/marker/testdata/mcptools/context.txt | 144 ++++++++++++ .../marker/testdata/mcptools/hello_world.txt | 17 -- 30 files changed, 428 insertions(+), 109 deletions(-) create mode 100644 gopls/internal/mcp/context.go create mode 100644 gopls/internal/test/marker/testdata/mcptools/context.txt delete mode 100644 gopls/internal/test/marker/testdata/mcptools/hello_world.txt diff --git a/gopls/internal/cache/session.go b/gopls/internal/cache/session.go index 82472e82a95..8a9a589b708 100644 --- a/gopls/internal/cache/session.go +++ b/gopls/internal/cache/session.go 
@@ -451,6 +451,21 @@ func (s *Session) SnapshotOf(ctx context.Context, uri protocol.DocumentURI) (*Sn return nil, nil, errNoViews } +// FileOf returns the file for a given URI and its snapshot. +// On success, the returned function must be called to release the snapshot. +func (s *Session) FileOf(ctx context.Context, uri protocol.DocumentURI) (file.Handle, *Snapshot, func(), error) { + snapshot, release, err := s.SnapshotOf(ctx, uri) + if err != nil { + return nil, nil, nil, err + } + fh, err := snapshot.ReadFile(ctx, uri) + if err != nil { + release() + return nil, nil, nil, err + } + return fh, snapshot, release, nil +} + // errNoViews is sought by orphaned file diagnostics, to detect the case where // we have no view containing a file. var errNoViews = errors.New("no views") diff --git a/gopls/internal/golang/symbols.go b/gopls/internal/golang/symbols.go index c49a498ab18..101ef7a06e3 100644 --- a/gopls/internal/golang/symbols.go +++ b/gopls/internal/golang/symbols.go @@ -169,7 +169,6 @@ func PackageSymbols(ctx context.Context, snapshot *cache.Snapshot, uri protocol. Files: pkgFiles, Symbols: symbols, }, nil - } func toPackageSymbol(fileIndex int, s protocol.DocumentSymbol) command.PackageSymbol { diff --git a/gopls/internal/mcp/context.go b/gopls/internal/mcp/context.go new file mode 100644 index 00000000000..9cac739c7a3 --- /dev/null +++ b/gopls/internal/mcp/context.go @@ -0,0 +1,208 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp + +// This file defines the "context" operation, which +// returns a summary of the specified package. 
+ +import ( + "bytes" + "context" + "fmt" + "go/ast" + "go/token" + "path/filepath" + "strings" + + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/metadata" + "golang.org/x/tools/gopls/internal/cache/parsego" + "golang.org/x/tools/gopls/internal/file" + "golang.org/x/tools/gopls/internal/golang" + "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/util/astutil" + "golang.org/x/tools/gopls/internal/util/bug" + "golang.org/x/tools/internal/mcp" +) + +type ContextParams struct { + // TODO(hxjiang): experiment if the LLM can correctly provide the right + // location information. + Location protocol.Location `json:"location"` +} + +func contextHandler(ctx context.Context, session *cache.Session, params *mcp.CallToolParams[ContextParams]) (*mcp.CallToolResult, error) { + fh, snapshot, release, err := session.FileOf(ctx, params.Arguments.Location.URI) + if err != nil { + return nil, err + } + defer release() + + // TODO(hxjiang): support context for GoMod. + if snapshot.FileKind(fh) != file.Go { + return nil, fmt.Errorf("can't provide context for non-Go file") + } + + pkg, pgf, err := golang.NarrowestPackageForFile(ctx, snapshot, params.Arguments.Location.URI) + if err != nil { + return nil, err + } + + var result strings.Builder + result.WriteString("Code blocks are delimited by --->...<--- markers.\n\n") + // TODO(hxjiang): consider making the context tool best effort. Ignore + // non-critical errors. + if err := writePackageSummary(ctx, snapshot, pkg, pgf, &result); err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []*mcp.Content{ + mcp.NewTextContent(result.String()), + }, + }, nil +} + +// writePackageSummary writes the package summaries to the bytes buffer based on +// the input import specs. 
+func writePackageSummary(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, out *strings.Builder) error { + if len(pgf.File.Imports) == 0 { + return nil + } + + fmt.Fprintf(out, "Current file %q contains this import declaration:\n", filepath.Base(pgf.URI.Path())) + out.WriteString("--->\n") + // Add all import decl to output including all floating comment by using + // GenDecl's start and end position. + for _, decl := range pgf.File.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + + if genDecl.Tok != token.IMPORT { + continue + } + + text, err := pgf.NodeText(genDecl) + if err != nil { + return err + } + + out.Write(text) + out.WriteString("\n") + } + out.WriteString("<---\n\n") + + out.WriteString("The imported packages declare the following symbols:\n\n") + + for _, imp := range pgf.File.Imports { + importPath := metadata.UnquoteImportPath(imp) + if importPath == "" { + continue + } + + impID := pkg.Metadata().DepsByImpPath[importPath] + if impID == "" { + return fmt.Errorf("no package data for import %q", importPath) + } + impMetadata := snapshot.Metadata(impID) + if impMetadata == nil { + return bug.Errorf("failed to resolve import ID %q", impID) + } + + fmt.Fprintf(out, "%s (package %s)\n", importPath, impMetadata.Name) + for _, f := range impMetadata.CompiledGoFiles { + fmt.Fprintf(out, "%s:\n", filepath.Base(f.Path())) + out.WriteString("--->\n") + fh, err := snapshot.ReadFile(ctx, f) + if err != nil { + return err + } + pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) + if err != nil { + return err + } + + // Copy everything before the first non-import declaration: + // package decl, imports decl(s), and all comments (excluding copyright). 
+ { + endPos := pgf.File.FileEnd + + outerloop: + for _, decl := range pgf.File.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if decl.Doc != nil { + endPos = decl.Doc.Pos() + } else { + endPos = decl.Pos() + } + break outerloop + case *ast.GenDecl: + if decl.Tok == token.IMPORT { + continue + } + if decl.Doc != nil { + endPos = decl.Doc.Pos() + } else { + endPos = decl.Pos() + } + break outerloop + } + } + + startPos := pgf.File.FileStart + if copyright := golang.CopyrightComment(pgf.File); copyright != nil { + startPos = copyright.End() + } + + text, err := pgf.PosText(startPos, endPos) + if err != nil { + return err + } + + out.Write(bytes.TrimSpace(text)) + out.WriteString("\n") + } + + // Write exported func decl and gen decl. + // TODO(hxjiang): write exported gen decl. + for _, decl := range pgf.File.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if !decl.Name.IsExported() { + continue + } + + if decl.Recv != nil && len(decl.Recv.List) > 0 { + _, rname, _ := astutil.UnpackRecv(decl.Recv.List[0].Type) + if !rname.IsExported() { + continue + } + } + + out.WriteString("\n") + // Write doc comment and func signature. 
+ startPos := decl.Pos() + if decl.Doc != nil { + startPos = decl.Doc.Pos() + } + + text, err := pgf.PosText(startPos, decl.Type.End()) + if err != nil { + return err + } + + out.Write(text) + out.WriteString("\n") + } + } + + out.WriteString("<---\n\n") + } + } + return nil +} diff --git a/gopls/internal/mcp/mcp.go b/gopls/internal/mcp/mcp.go index b3a897d359f..7887eeab5b2 100644 --- a/gopls/internal/mcp/mcp.go +++ b/gopls/internal/mcp/mcp.go @@ -10,7 +10,6 @@ import ( "log" "net" "net/http" - "path" "sync" "golang.org/x/tools/gopls/internal/cache" @@ -114,48 +113,37 @@ func HTTPHandler(eventChan <-chan lsprpc.SessionEvent, cache *cache.Cache, isDae func newServer(_ *cache.Cache, session *cache.Session) *mcp.Server { s := mcp.NewServer("golang", "v0.1", nil) - // TODO(hxjiang): replace dummy tool with tools which use cache and session. s.AddTools( mcp.NewTool( - "hello_world", - "Say hello to someone", - func(ctx context.Context, _ *mcp.ServerSession, params *mcp.CallToolParams[HelloParams]) (*mcp.CallToolResult, error) { - return helloHandler(ctx, session, params) + "context", + "Provide context for a region within a Go file", + func(ctx context.Context, _ *mcp.ServerSession, request *mcp.CallToolParams[ContextParams]) (*mcp.CallToolResult, error) { + return contextHandler(ctx, session, request) }, + mcp.Input( + mcp.Property( + "location", + mcp.Description("location inside of a text file"), + mcp.Property("uri", mcp.Description("URI of the text document")), + mcp.Property("range", + mcp.Description("range within text document"), + mcp.Property( + "start", + mcp.Description("start position of range"), + mcp.Property("line", mcp.Description("line number (zero-based)")), + mcp.Property("character", mcp.Description("column number (zero-based, UTF-16 encoding)")), + ), + mcp.Property( + "end", + mcp.Description("end position of range"), + mcp.Property("line", mcp.Description("line number (zero-based)")), + mcp.Property("character", mcp.Description("column number 
(zero-based, UTF-16 encoding)")), + ), + ), + ), + ), ), ) - return s -} - -type HelloParams struct { - Name string `json:"name" mcp:"the name to say hi to"` - Location Location `json:"loc" mcp:"location inside of a text file"` -} - -func helloHandler(_ context.Context, _ *cache.Session, params *mcp.CallToolParams[HelloParams]) (*mcp.CallToolResult, error) { - return &mcp.CallToolResult{ - Content: []*mcp.Content{ - mcp.NewTextContent(fmt.Sprintf("Hi %s, current file %s.", params.Arguments.Name, path.Base(params.Arguments.Location.URI))), - }, - }, nil -} - -// Location describes a range within a text document. -// -// It is structurally equal to protocol.Location, but has mcp tags instead of json. -// TODO(hxjiang): experiment if the LLM can correctly provide the right location -// information. -type Location struct { - URI string `json:"uri" mcp:"URI to the text file"` - Range Range `json:"range" mcp:"range within text document"` -} -type Range struct { - Start Position `json:"start" mcp:"the range's start position"` - End Position `json:"end" mcp:"the range's end position"` -} - -type Position struct { - Line uint32 `json:"line" mcp:"line number (zero-based)"` - Character uint32 `json:"character" mcp:"column number (zero-based, UTF-16 encoding)"` + return s } diff --git a/gopls/internal/server/call_hierarchy.go b/gopls/internal/server/call_hierarchy.go index 1887767250c..dc8cd4cec2a 100644 --- a/gopls/internal/server/call_hierarchy.go +++ b/gopls/internal/server/call_hierarchy.go @@ -17,7 +17,7 @@ func (s *server) PrepareCallHierarchy(ctx context.Context, params *protocol.Call ctx, done := event.Start(ctx, "server.PrepareCallHierarchy") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } @@ -33,7 +33,7 @@ func (s *server) IncomingCalls(ctx context.Context, params *protocol.CallHierarc ctx, done := event.Start(ctx, 
"server.IncomingCalls") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.Item.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.Item.URI) if err != nil { return nil, err } @@ -49,7 +49,7 @@ func (s *server) OutgoingCalls(ctx context.Context, params *protocol.CallHierarc ctx, done := event.Start(ctx, "server.OutgoingCalls") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.Item.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.Item.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/code_action.go b/gopls/internal/server/code_action.go index 9fa2bf54459..e37cfc9f73e 100644 --- a/gopls/internal/server/code_action.go +++ b/gopls/internal/server/code_action.go @@ -24,7 +24,7 @@ func (s *server) CodeAction(ctx context.Context, params *protocol.CodeActionPara ctx, done := event.Start(ctx, "server.CodeAction") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/code_lens.go b/gopls/internal/server/code_lens.go index 2509452f0b5..b644f8b7fff 100644 --- a/gopls/internal/server/code_lens.go +++ b/gopls/internal/server/code_lens.go @@ -25,7 +25,7 @@ func (s *server) CodeLens(ctx context.Context, params *protocol.CodeLensParams) ctx, done := event.Start(ctx, "server.CodeLens", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index 8782dfd1460..41de2cd7c7e 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -355,7 +355,7 @@ func (c *commandHandler) run(ctx context.Context, cfg commandConfig, run command return 
bug.Errorf("internal error: forURI=%q, forView=%q", cfg.forURI, cfg.forView) } if cfg.forURI != "" { - deps.fh, deps.snapshot, release, err = c.s.fileOf(ctx, cfg.forURI) + deps.fh, deps.snapshot, release, err = c.s.session.FileOf(ctx, cfg.forURI) if err != nil { return err } @@ -546,7 +546,7 @@ func (c *commandHandler) UpdateGoSum(ctx context.Context, args command.URIArgs) progress: "Updating go.sum", }, func(ctx context.Context, _ commandDeps) error { for _, uri := range args.URIs { - fh, snapshot, release, err := c.s.fileOf(ctx, uri) + fh, snapshot, release, err := c.s.session.FileOf(ctx, uri) if err != nil { return err } @@ -567,7 +567,7 @@ func (c *commandHandler) Tidy(ctx context.Context, args command.URIArgs) error { progress: "Running go mod tidy", }, func(ctx context.Context, _ commandDeps) error { for _, uri := range args.URIs { - fh, snapshot, release, err := c.s.fileOf(ctx, uri) + fh, snapshot, release, err := c.s.session.FileOf(ctx, uri) if err != nil { return err } @@ -616,7 +616,7 @@ func (c *commandHandler) EditGoDirective(ctx context.Context, args command.EditG requireSave: true, // if go.mod isn't saved it could cause a problem forURI: args.URI, }, func(ctx context.Context, _ commandDeps) error { - fh, snapshot, release, err := c.s.fileOf(ctx, args.URI) + fh, snapshot, release, err := c.s.session.FileOf(ctx, args.URI) if err != nil { return err } @@ -1650,7 +1650,7 @@ func (c *commandHandler) DiagnoseFiles(ctx context.Context, args command.Diagnos snapshots := make(map[*cache.Snapshot]bool) for _, uri := range args.Files { - fh, snapshot, release, err := c.s.fileOf(ctx, uri) + fh, snapshot, release, err := c.s.session.FileOf(ctx, uri) if err != nil { return err } diff --git a/gopls/internal/server/completion.go b/gopls/internal/server/completion.go index 02604b2f710..21f1be040f2 100644 --- a/gopls/internal/server/completion.go +++ b/gopls/internal/server/completion.go @@ -30,7 +30,7 @@ func (s *server) Completion(ctx context.Context, params 
*protocol.CompletionPara ctx, done := event.Start(ctx, "server.Completion", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/definition.go b/gopls/internal/server/definition.go index 8b9d42413be..34863ffa134 100644 --- a/gopls/internal/server/definition.go +++ b/gopls/internal/server/definition.go @@ -28,7 +28,7 @@ func (s *server) Definition(ctx context.Context, params *protocol.DefinitionPara defer done() // TODO(rfindley): definition requests should be multiplexed across all views. - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (s *server) TypeDefinition(ctx context.Context, params *protocol.TypeDefini defer done() // TODO(rfindley): type definition requests should be multiplexed across all views. 
- fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/diagnostics.go b/gopls/internal/server/diagnostics.go index 45739940a03..2a529ef6624 100644 --- a/gopls/internal/server/diagnostics.go +++ b/gopls/internal/server/diagnostics.go @@ -47,7 +47,7 @@ func (s *server) Diagnostic(ctx context.Context, params *protocol.DocumentDiagno ctx, done := event.Start(ctx, "server.Diagnostic") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/folding_range.go b/gopls/internal/server/folding_range.go index 5dbfd697db4..d570234065f 100644 --- a/gopls/internal/server/folding_range.go +++ b/gopls/internal/server/folding_range.go @@ -18,7 +18,7 @@ func (s *server) FoldingRange(ctx context.Context, params *protocol.FoldingRange ctx, done := event.Start(ctx, "server.FoldingRange", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/format.go b/gopls/internal/server/format.go index 6abbb96d5b6..de30aeb9482 100644 --- a/gopls/internal/server/format.go +++ b/gopls/internal/server/format.go @@ -20,7 +20,7 @@ func (s *server) Formatting(ctx context.Context, params *protocol.DocumentFormat ctx, done := event.Start(ctx, "server.Formatting", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/general.go 
b/gopls/internal/server/general.go index 0eea8f641f7..b3b6c96a735 100644 --- a/gopls/internal/server/general.go +++ b/gopls/internal/server/general.go @@ -24,7 +24,6 @@ import ( "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/debug" debuglog "golang.org/x/tools/gopls/internal/debug/log" - "golang.org/x/tools/gopls/internal/file" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/semtok" "golang.org/x/tools/gopls/internal/settings" @@ -623,21 +622,6 @@ func (s *server) handleOptionResult(ctx context.Context, applied []telemetry.Cou } } -// fileOf returns the file for a given URI and its snapshot. -// On success, the returned function must be called to release the snapshot. -func (s *server) fileOf(ctx context.Context, uri protocol.DocumentURI) (file.Handle, *cache.Snapshot, func(), error) { - snapshot, release, err := s.session.SnapshotOf(ctx, uri) - if err != nil { - return nil, nil, nil, err - } - fh, err := snapshot.ReadFile(ctx, uri) - if err != nil { - release() - return nil, nil, nil, err - } - return fh, snapshot, release, nil -} - // Shutdown implements the 'shutdown' LSP handler. It releases resources // associated with the server and waits for all ongoing work to complete. 
func (s *server) Shutdown(ctx context.Context) error { diff --git a/gopls/internal/server/highlight.go b/gopls/internal/server/highlight.go index 04ebbfa25ec..6ff73d84b37 100644 --- a/gopls/internal/server/highlight.go +++ b/gopls/internal/server/highlight.go @@ -19,7 +19,7 @@ func (s *server) DocumentHighlight(ctx context.Context, params *protocol.Documen ctx, done := event.Start(ctx, "server.DocumentHighlight", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/hover.go b/gopls/internal/server/hover.go index ed70ce493ba..180bfcff7df 100644 --- a/gopls/internal/server/hover.go +++ b/gopls/internal/server/hover.go @@ -28,7 +28,7 @@ func (s *server) Hover(ctx context.Context, params *protocol.HoverParams) (_ *pr ctx, done := event.Start(ctx, "server.Hover", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/implementation.go b/gopls/internal/server/implementation.go index 9b2c103b2c3..c4e59c9d3e0 100644 --- a/gopls/internal/server/implementation.go +++ b/gopls/internal/server/implementation.go @@ -24,7 +24,7 @@ func (s *server) Implementation(ctx context.Context, params *protocol.Implementa ctx, done := event.Start(ctx, "server.Implementation", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/inlay_hint.go b/gopls/internal/server/inlay_hint.go index a11ab4c313a..e7a46806ba7 100644 --- 
a/gopls/internal/server/inlay_hint.go +++ b/gopls/internal/server/inlay_hint.go @@ -19,7 +19,7 @@ func (s *server) InlayHint(ctx context.Context, params *protocol.InlayHintParams ctx, done := event.Start(ctx, "server.InlayHint", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/link.go b/gopls/internal/server/link.go index 52e8ca379c5..a98e2bc2688 100644 --- a/gopls/internal/server/link.go +++ b/gopls/internal/server/link.go @@ -34,7 +34,7 @@ func (s *server) DocumentLink(ctx context.Context, params *protocol.DocumentLink ctx, done := event.Start(ctx, "server.DocumentLink") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/references.go b/gopls/internal/server/references.go index 8a01e96498b..2916f27c093 100644 --- a/gopls/internal/server/references.go +++ b/gopls/internal/server/references.go @@ -25,7 +25,7 @@ func (s *server) References(ctx context.Context, params *protocol.ReferenceParam ctx, done := event.Start(ctx, "server.References", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/rename.go b/gopls/internal/server/rename.go index 218740bd679..8e876576b40 100644 --- a/gopls/internal/server/rename.go +++ b/gopls/internal/server/rename.go @@ -20,7 +20,7 @@ func (s *server) Rename(ctx context.Context, params *protocol.RenameParams) (*pr ctx, done := event.Start(ctx, "server.Rename", label.URI.Of(params.TextDocument.URI)) defer done() - fh, 
snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (s *server) PrepareRename(ctx context.Context, params *protocol.PrepareRena ctx, done := event.Start(ctx, "server.PrepareRename", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/selection_range.go b/gopls/internal/server/selection_range.go index afc878b1544..a70398e1d65 100644 --- a/gopls/internal/server/selection_range.go +++ b/gopls/internal/server/selection_range.go @@ -30,7 +30,7 @@ func (s *server) SelectionRange(ctx context.Context, params *protocol.SelectionR ctx, done := event.Start(ctx, "server.SelectionRange") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/semantic.go b/gopls/internal/server/semantic.go index f0a2e11dd98..6cbcb661937 100644 --- a/gopls/internal/server/semantic.go +++ b/gopls/internal/server/semantic.go @@ -27,7 +27,7 @@ func (s *server) semanticTokens(ctx context.Context, td protocol.TextDocumentIde ctx, done := event.Start(ctx, "server.semanticTokens", label.URI.Of(td.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, td.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, td.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/signature_help.go b/gopls/internal/server/signature_help.go index 8850a1624ae..b273cef9831 100644 --- a/gopls/internal/server/signature_help.go +++ b/gopls/internal/server/signature_help.go @@ -18,7 +18,7 @@ func (s *server) SignatureHelp(ctx context.Context, params 
*protocol.SignatureHe ctx, done := event.Start(ctx, "server.SignatureHelp", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/symbols.go b/gopls/internal/server/symbols.go index 334154add5b..5806fa581f2 100644 --- a/gopls/internal/server/symbols.go +++ b/gopls/internal/server/symbols.go @@ -19,7 +19,7 @@ func (s *server) DocumentSymbol(ctx context.Context, params *protocol.DocumentSy ctx, done := event.Start(ctx, "server.DocumentSymbol", label.URI.Of(params.TextDocument.URI)) defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/type_hierarchy.go b/gopls/internal/server/type_hierarchy.go index 5f40ed3c0c2..037c8c5957c 100644 --- a/gopls/internal/server/type_hierarchy.go +++ b/gopls/internal/server/type_hierarchy.go @@ -18,7 +18,7 @@ func (s *server) PrepareTypeHierarchy(ctx context.Context, params *protocol.Type ctx, done := event.Start(ctx, "server.PrepareTypeHierarchy") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.TextDocument.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.TextDocument.URI) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (s *server) Subtypes(ctx context.Context, params *protocol.TypeHierarchySub ctx, done := event.Start(ctx, "server.Subtypes") defer done() - fh, snapshot, release, err := s.fileOf(ctx, params.Item.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.Item.URI) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (s *server) Supertypes(ctx context.Context, params *protocol.TypeHierarchyS ctx, done := event.Start(ctx, "server.Supertypes") defer done() - fh, snapshot, release, 
err := s.fileOf(ctx, params.Item.URI) + fh, snapshot, release, err := s.session.FileOf(ctx, params.Item.URI) if err != nil { return nil, err } diff --git a/gopls/internal/server/workspace.go b/gopls/internal/server/workspace.go index 01e2c69d8ee..0a9536de476 100644 --- a/gopls/internal/server/workspace.go +++ b/gopls/internal/server/workspace.go @@ -149,7 +149,7 @@ func (s *server) DidCreateFiles(ctx context.Context, params *protocol.CreateFile var allChanges []protocol.DocumentChange for _, createdFile := range params.Files { uri := protocol.DocumentURI(createdFile.URI) - fh, snapshot, release, err := s.fileOf(ctx, uri) + fh, snapshot, release, err := s.session.FileOf(ctx, uri) if err != nil { event.Error(ctx, "fail to call fileOf", err) continue diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 0e4a1026dfb..d719a758c1c 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -2449,10 +2449,8 @@ func mcpToolMarker(mark marker, tool string, rawArgs string, loc protocol.Locati return } - // Inserts the location value into the MCP tool arguments map under the - // "loc" key. - // TODO(hxjiang): Make the "loc" key configurable. - args["loc"] = loc + // TODO(hxjiang): Make the "location" key configurable. + args["location"] = loc res, err := mcp.CallTool(mark.ctx(), mark.run.env.MCPSession, &mcp.CallToolParams[map[string]any]{ Name: tool, diff --git a/gopls/internal/test/marker/testdata/mcptools/context.txt b/gopls/internal/test/marker/testdata/mcptools/context.txt new file mode 100644 index 00000000000..a6dba65f8d4 --- /dev/null +++ b/gopls/internal/test/marker/testdata/mcptools/context.txt @@ -0,0 +1,144 @@ +This test exercises mcp tool context. + +-- flags -- +-mcp +-ignore_extra_diags + +-- go.mod -- +module example.com + +-- comment/doc.go -- +// Copyright 2025 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package doc for package comment. +*/ +package comment + +-- comment/foo.go -- +// File doc for foo.go part 1. +package comment + +// File doc for foo.go part 2. +import ( + // comment for package renaming. + myfmt "fmt" +) + +// File doc for foo.go part 3. + +// Func doc for comment.Foo +func Foo(foo string, _ int) { + myfmt.Printf("%s", foo) +} + +// Random comment floating around. + +-- comment.go -- +package main + +import( + "example.com/comment" +) + +func testComment() { //@loc(comment, "test") + comment.Foo("", 0) + function.Foo(0, "") +} + +//@mcptool("context", `{}`, comment, output=withComment) + +-- @withComment -- +Code blocks are delimited by --->...<--- markers. + +Current file "comment.go" contains this import declaration: +---> +import( + "example.com/comment" +) +<--- + +The imported packages declare the following symbols: + +example.com/comment (package comment) +doc.go: +---> +/* +Package doc for package comment. +*/ +package comment +<--- + +foo.go: +---> +// File doc for foo.go part 1. +package comment + +// File doc for foo.go part 2. +import ( + // comment for package renaming. + myfmt "fmt" +) + +// File doc for foo.go part 3. + +// Func doc for comment.Foo +func Foo(foo string, _ int) +<--- + +-- function/foo.go -- +package function + +func Foo(int, string) {} + +func foo(string, int) {} + +type unexported struct {} + +func (*unexported) unexported(int) {} + +func (*unexported) Exported(int) {} + +type Exported struct{} + +func (*Exported) unexported(int) {} + +func (*Exported) Exported(int) {} + +-- function.go -- +package main + +import( + "example.com/function" +) + +func testFunction() { //@loc(function, "test") + function.Foo(0, "") +} + +//@mcptool("context", `{}`, function, output=withFunction) + +-- @withFunction -- +Code blocks are delimited by --->...<--- markers. 
+ +Current file "function.go" contains this import declaration: +---> +import( + "example.com/function" +) +<--- + +The imported packages declare the following symbols: + +example.com/function (package function) +foo.go: +---> +package function + +func Foo(int, string) + +func (*Exported) Exported(int) +<--- + diff --git a/gopls/internal/test/marker/testdata/mcptools/hello_world.txt b/gopls/internal/test/marker/testdata/mcptools/hello_world.txt deleted file mode 100644 index 8ae6f745565..00000000000 --- a/gopls/internal/test/marker/testdata/mcptools/hello_world.txt +++ /dev/null @@ -1,17 +0,0 @@ -This test exercises mcp tool hello_world. - --- flags -- --mcp - --- go.mod -- -module golang.org/mcptests/mcptools - --- mcp/tools/helloworld.go -- -package helloworld - -func A() {} //@loc(loc, "A") - -//@mcptool("hello_world", `{"name": "jerry"}`, loc, output=hello) - --- @hello -- -Hi jerry, current file helloworld.go. From 8fbc77367e116d3f2b7fdd46d7f4f39537a12762 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Jun 2025 14:44:07 -0400 Subject: [PATCH 172/196] gopls/internal/analysis/modernize: add TODOs for nilness problem append(x, y...) and slices.Concat(x, y) differ in the nilness of their results when x and y are empty. This may lead to unsound transformations. I don't know how to realistically fix this bug short of turning off this modernizer. For now, document the problem in more detail.
Updates golang/go#73557 Change-Id: I30e9dceb93d11e6632b30eee28b7b5a2518722eb Reviewed-on: https://go-review.googlesource.com/c/tools/+/678135 Reviewed-by: Robert Findley Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/analysis/modernize/slices.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/gopls/internal/analysis/modernize/slices.go b/gopls/internal/analysis/modernize/slices.go index 18e02d51ebf..2725cb6d8a4 100644 --- a/gopls/internal/analysis/modernize/slices.go +++ b/gopls/internal/analysis/modernize/slices.go @@ -44,7 +44,7 @@ import ( // append([]string(nil), os.Environ()...) -> os.Environ() // // The fix does not always preserve nilness the of base slice when the -// addends (a, b, c) are all empty. +// addends (a, b, c) are all empty (see #73557). func appendclipped(pass *analysis.Pass) { // Skip the analyzer in packages where its // fixes would create an import cycle. @@ -67,6 +67,10 @@ func appendclipped(pass *analysis.Pass) { // If the (clipped) base is empty, it may be safely ignored. // Otherwise treat it (or its unclipped subexpression, if possible) // as just another arg (the first) to Concat. + // + // TODO(adonovan): not so fast! If all the operands + // are empty, then the nilness of base matters, because + // append preserves nilness whereas Concat does not (#73557). if !empty { sliceArgs = append(sliceArgs, clipped) } @@ -118,6 +122,9 @@ func appendclipped(pass *analysis.Pass) { "slices") // append(zerocap, s...) -> slices.Clone(s) or bytes.Clone(s) + // + // This is unsound if s is empty and its nilness + // differs from zerocap (#73557). _, prefix, importEdits := analysisinternal.AddImport(info, file, clonepkg, clonepkg, "Clone", call.Pos()) message := fmt.Sprintf("Replace append with %s.Clone", clonepkg) pass.Report(analysis.Diagnostic{ @@ -138,6 +145,8 @@ func appendclipped(pass *analysis.Pass) { } // append(append(append(base, a...), b..., c...) 
-> slices.Concat(base, a, b, c) + // + // This is unsound if all slices are empty and base is non-nil (#73557). _, prefix, importEdits := analysisinternal.AddImport(info, file, "slices", "slices", "Concat", call.Pos()) pass.Report(analysis.Diagnostic{ Pos: call.Pos(), @@ -200,7 +209,7 @@ func appendclipped(pass *analysis.Pass) { // The value of res is either the same as e or is a subexpression of e // that denotes the same slice but without the clipping operation. // -// In addition, it reports whether the slice is definitely empty, +// In addition, it reports whether the slice is definitely empty. // // Examples of clipped slices: // From 2246f6dc306a3d42cc9d5a1af844dc755a78128a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 25 May 2025 07:56:36 -0400 Subject: [PATCH 173/196] internal/mcp: middleware on the sending side Implement sending middleware. Sending middleware is useful for adding tracing, progress tokens, and so on. Change-Id: I0acab2550f0212d549c4adbcf1ebf425a9486df6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/676355 Reviewed-by: Sam Thanawalla Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 100 +++++++++++++++++--------- internal/mcp/design/design.md | 10 +-- internal/mcp/example_progress_test.go | 29 ++++++++ internal/mcp/mcp_test.go | 76 ++++++++++++-------- internal/mcp/server.go | 93 +++++++++++++++--------- internal/mcp/shared.go | 83 +++++++++++++++------ internal/mcp/transport.go | 4 +- 7 files changed, 274 insertions(+), 121 deletions(-) create mode 100644 internal/mcp/example_progress_test.go diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 068a9e3877e..8f47f6fd420 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -18,13 +18,14 @@ import ( // A Client is an MCP client, which may be connected to an MCP server // using the [Client.Connect] method. 
type Client struct { - name string - version string - opts ClientOptions - mu sync.Mutex - roots *featureSet[*Root] - sessions []*ClientSession - methodHandler_ MethodHandler[*ClientSession] + name string + version string + opts ClientOptions + mu sync.Mutex + roots *featureSet[*Root] + sessions []*ClientSession + sendingMethodHandler_ MethodHandler[*ClientSession] + receivingMethodHandler_ MethodHandler[*ClientSession] } // NewClient creates a new Client. @@ -34,10 +35,11 @@ type Client struct { // If non-nil, the provided options configure the Client. func NewClient(name, version string, opts *ClientOptions) *Client { c := &Client{ - name: name, - version: version, - roots: newFeatureSet(func(r *Root) string { return r.URI }), - methodHandler_: defaultMethodHandler[*ClientSession], + name: name, + version: version, + roots: newFeatureSet(func(r *Root) string { return r.URI }), + sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession], + receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } if opts != nil { c.opts = *opts @@ -103,11 +105,13 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e ClientInfo: &implementation{Name: c.name, Version: c.version}, Capabilities: caps, } - if err := call(ctx, cs.conn, "initialize", params, &cs.initializeResult); err != nil { + res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) + if err != nil { _ = cs.Close() return nil, err } - if err := cs.conn.Notify(ctx, notificationInitialized, &InitializedParams{}); err != nil { + cs.initializeResult = res + if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() return nil, err } @@ -190,16 +194,34 @@ func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *C return c.opts.CreateMessageHandler(ctx, cs, params) } -// AddMiddleware wraps the client's current method handler using the provided -// middleware. 
Middleware is applied from right to left, so that the first one -// is executed first. +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one is +// executed first. // -// For example, AddMiddleware(m1, m2, m3) augments the client method handler as +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as // m1(m2(m3(handler))). -func (c *Client) AddMiddleware(middleware ...Middleware[*ClientSession]) { +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) { c.mu.Lock() defer c.mu.Unlock() - addMiddleware(&c.methodHandler_, middleware) + addMiddleware(&c.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. It is useful for tasks +// such as authentication, request logging and metrics. +func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.receivingMethodHandler_, middleware) } // clientMethodInfos maps from the RPC method name to serverMethodInfos. 
@@ -213,51 +235,62 @@ var clientMethodInfos = map[string]methodInfo{ notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), } -func (cs *ClientSession) methodInfos() map[string]methodInfo { +func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { + return serverMethodInfos +} + +func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { return clientMethodInfos } func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { - return handleRequest(ctx, req, cs) + return handleReceive(ctx, cs, req) +} + +func (cs *ClientSession) sendingMethodHandler() methodHandler { + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.sendingMethodHandler_ } -func (cs *ClientSession) methodHandler() methodHandler { +func (cs *ClientSession) receivingMethodHandler() methodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() - return cs.client.methodHandler_ + return cs.client.receivingMethodHandler_ } // getConn implements [session.getConn]. func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } -func (cs *ClientSession) ping(ct context.Context, params *PingParams) (Result, error) { - return emptyResult{}, nil +func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { + return &emptyResult{}, nil } // Ping makes an MCP "ping" request to the server. func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { - return call(ctx, cs.conn, methodPing, params, nil) + _, err := handleSend[*emptyResult](ctx, cs, methodPing, params) + return err } // ListPrompts lists prompts that are currently available on the server. 
func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { - return standardCall[ListPromptsResult](ctx, cs.conn, methodListPrompts, params) + return handleSend[*ListPromptsResult](ctx, cs, methodListPrompts, params) } // GetPrompt gets a prompt from the server. func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { - return standardCall[GetPromptResult](ctx, cs.conn, methodGetPrompt, params) + return handleSend[*GetPromptResult](ctx, cs, methodGetPrompt, params) } // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return standardCall[ListToolsResult](ctx, cs.conn, methodListTools, params) + return handleSend[*ListToolsResult](ctx, cs, methodListTools, params) } // CallTool calls the tool with the given name and arguments. // Pass a [CallToolOptions] to provide additional request fields. func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { - return standardCall[CallToolResult](ctx, cs.conn, methodCallTool, params) + return handleSend[*CallToolResult](ctx, cs, methodCallTool, params) } // CallTool is a helper to call a tool with any argument type. It returns an @@ -286,17 +319,18 @@ func toWireParams[TArgs any](params *CallToolParams[TArgs]) (*CallToolParams[jso } func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { - return call(ctx, cs.conn, methodSetLevel, params, nil) + _, err := handleSend[*emptyResult](ctx, cs, methodSetLevel, params) + return err } // ListResources lists the resources that are currently available on the server. 
func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { - return standardCall[ListResourcesResult](ctx, cs.conn, methodListResources, params) + return handleSend[*ListResourcesResult](ctx, cs, methodListResources, params) } // ReadResource ask the server to read a resource and return its contents. func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { - return standardCall[ReadResourceResult](ctx, cs.conn, methodReadResource, params) + return handleSend[*ReadResourceResult](ctx, cs, methodReadResource, params) } func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { diff --git a/internal/mcp/design/design.md b/internal/mcp/design/design.md index 0ed19b2cd5d..87adfebf322 100644 --- a/internal/mcp/design/design.md +++ b/internal/mcp/design/design.md @@ -408,7 +408,7 @@ func (*ClientSession) ResourceTemplates(context.Context, *ListResourceTemplatesP ### Middleware -We provide a mechanism to add MCP-level middleware on the both the client and server side, which runs after the request has been parsed but before any normal handling. +We provide a mechanism to add MCP-level middleware on the both the client and server side. Receiving middleware runs after the request has been parsed but before any normal handling. It is analogous to traditional HTTP server middleware. Sending middleware runs after a call to a method but before the request is sent. It is an alternative to transport middleware that exposes MCP types instead of raw JSON-RPC 2.0 messages. It is useful for tracing and setting progress tokens, for example. ```go // A MethodHandler handles MCP messages. @@ -428,8 +428,10 @@ type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // // For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). 
-func (c *Client) AddMiddleware(middleware ...Middleware[*ClientSession]) -func (s *Server) AddMiddleware(middleware ...Middleware[*ServerSession]) +func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) +func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) +func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) +func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]) ``` As an example, this code adds server-side logging: @@ -443,7 +445,7 @@ func withLogging(h mcp.MethodHandler[*ServerSession]) mcp.MethodHandler[*ServerS } } -server.AddMiddleware(withLogging) +server.AddReceivingMiddleware(withLogging) ``` **Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. These are rarely used. The most common is `OnError`, which occurs fewer than ten times in open-source code. diff --git a/internal/mcp/example_progress_test.go b/internal/mcp/example_progress_test.go new file mode 100644 index 00000000000..ab281aabf11 --- /dev/null +++ b/internal/mcp/example_progress_test.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "sync/atomic" + + "golang.org/x/tools/internal/mcp" +) + +var nextProgressToken atomic.Int64 + +// This middleware function adds a progress token to every outgoing request +// from the client. 
+func Example_progressMiddleware() { + c := mcp.NewClient("test", "v1", nil) + c.AddSendingMiddleware(addProgressToken[*mcp.ClientSession]) + _ = c +} + +func addProgressToken[S mcp.Session](h mcp.MethodHandler[S]) mcp.MethodHandler[S] { + return func(ctx context.Context, s S, method string, params mcp.Params) (result mcp.Result, err error) { + params.GetMeta().ProgressToken = nextProgressToken.Add(1) + return h(ctx, s, method, params) + } +} diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index c8394e10c2f..96a3dd7269c 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -117,7 +117,7 @@ func TestEndToEnd(t *testing.T) { t.Run("prompts", func(t *testing.T) { res, err := cs.ListPrompts(ctx, nil) if err != nil { - t.Errorf("prompts/list failed: %v", err) + t.Fatalf("prompts/list failed: %v", err) } wantPrompts := []*Prompt{ { @@ -634,10 +634,12 @@ func TestMiddleware(t *testing.T) { // "1" is the outer middleware layer, called first; then "2" is called, and finally // the default dispatcher. 
- s.AddMiddleware(traceCalls[*ServerSession](&sbuf, "1"), traceCalls[*ServerSession](&sbuf, "2")) + s.AddSendingMiddleware(traceCalls[*ServerSession](&sbuf, "S1"), traceCalls[*ServerSession](&sbuf, "S2")) + s.AddReceivingMiddleware(traceCalls[*ServerSession](&sbuf, "R1"), traceCalls[*ServerSession](&sbuf, "R2")) c := NewClient("testClient", "v1.0.0", nil) - c.AddMiddleware(traceCalls[*ClientSession](&cbuf, "1"), traceCalls[*ClientSession](&cbuf, "2")) + c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) + c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) cs, err := c.Connect(ctx, ct) if err != nil { @@ -646,34 +648,52 @@ func TestMiddleware(t *testing.T) { if _, err := cs.ListTools(ctx, nil); err != nil { t.Fatal(err) } - want := ` -1 >initialize -2 >initialize -2 notifications/initialized -2 >notifications/initialized -2 tools/list -2 >tools/list -2 roots/list -2 >roots/list -2 initialize +R2 >initialize +R2 notifications/initialized +R2 >notifications/initialized +R2 tools/list +R2 >tools/list +R2 roots/list +S2 >roots/list +S2 initialize +S2 >initialize +S2 notifications/initialized +S2 >notifications/initialized +S2 tools/list +S2 >tools/list +S2 roots/list +R2 >roots/list +R2 Date: Mon, 2 Jun 2025 14:59:28 -0400 Subject: [PATCH 174/196] gopls/internal/analysis/modernize: appendclipped: preserve result type This CL prevents appendclipped from offering a fix that might change the result type of the expression. For example: type ( S1 []int; S2 []int ) var a S1 -var b S2 = append([]int{}, a...) // ok +var b S2 = slices.Clone(a) // error: cannot assign... (A better fix would be to insert an explicit instantiation slices.Clone[S2], but that's non-trivial.) Thanks to Ethan Reesor and Xie Yuchen for analyzing the problem; see CL 671915 and CL 671975.
Fixes golang/go#73661 Change-Id: I90d60168f2e1746854a7939238201b0f5861212d Reviewed-on: https://go-review.googlesource.com/c/tools/+/678116 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/internal/analysis/modernize/slices.go | 25 +++++++++++++++++++ .../src/appendclipped/appendclipped.go | 11 +++++++- .../src/appendclipped/appendclipped.go.golden | 11 +++++++- gopls/internal/protocol/uri.go | 2 +- 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/gopls/internal/analysis/modernize/slices.go b/gopls/internal/analysis/modernize/slices.go index 2725cb6d8a4..f7461a043b5 100644 --- a/gopls/internal/analysis/modernize/slices.go +++ b/gopls/internal/analysis/modernize/slices.go @@ -64,6 +64,31 @@ func appendclipped(pass *analysis.Pass) { return } + // If any slice arg has a different type from the base + // (and thus the result) don't offer a fix, to avoid + // changing the return type, e.g: + // + // type S []int + // - x := append([]int(nil), S{}...) // x : []int + // + x := slices.Clone(S{}) // x : S + // + // We could do better by inserting an explicit generic + // instantiation: + // + // x := slices.Clone[[]int](S{}) + // + // but this is often unnecessary and unwanted, such as + // when the value is used an in assignment context that + // provides an explicit type: + // + // var x []int = slices.Clone(S{}) + baseType := info.TypeOf(base) + for _, arg := range sliceArgs { + if !types.Identical(info.TypeOf(arg), baseType) { + return + } + } + // If the (clipped) base is empty, it may be safely ignored. // Otherwise treat it (or its unclipped subexpression, if possible) // as just another arg (the first) to Concat. 
diff --git a/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go b/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go index c4e98535a37..6c1ae3eca37 100644 --- a/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go +++ b/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go @@ -5,7 +5,10 @@ import ( "slices" ) -type Bytes []byte +type ( + Bytes []byte + Bytes2 []byte +) func _(s, other []string) { print(append([]string{}, s...)) // want "Replace append with slices.Clone" @@ -24,3 +27,9 @@ func _(s, other []string) { print(append(append(slices.Clip(other), s...), other...)) // want "Replace append with slices.Concat" print(append(append(append(other[:0], s...), other...), other...)) // nope: intent may be to mutate other } + +var ( + _ Bytes = append(Bytes(nil), []byte(nil)...) // nope: correct fix requires Clone[Bytes] (#73661) + _ Bytes = append([]byte(nil), Bytes(nil)...) // nope: correct fix requires Clone[Bytes] (#73661) + _ Bytes2 = append([]byte(nil), Bytes(nil)...) 
// nope: correct fix requires Clone[Bytes2] (#73661) +) diff --git a/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go.golden b/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go.golden index 6352d525b34..173582f25fb 100644 --- a/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go.golden +++ b/gopls/internal/analysis/modernize/testdata/src/appendclipped/appendclipped.go.golden @@ -5,7 +5,10 @@ import ( "slices" ) -type Bytes []byte +type ( + Bytes []byte + Bytes2 []byte +) func _(s, other []string) { print(slices.Clone(s)) // want "Replace append with slices.Clone" @@ -24,3 +27,9 @@ func _(s, other []string) { print(slices.Concat(other, s, other)) // want "Replace append with slices.Concat" print(append(append(append(other[:0], s...), other...), other...)) // nope: intent may be to mutate other } + +var ( + _ Bytes = append(Bytes(nil), []byte(nil)...) // nope: correct fix requires Clone[Bytes] (#73661) + _ Bytes = append([]byte(nil), Bytes(nil)...) // nope: correct fix requires Clone[Bytes] (#73661) + _ Bytes2 = append([]byte(nil), Bytes(nil)...) // nope: correct fix requires Clone[Bytes2] (#73661) +) diff --git a/gopls/internal/protocol/uri.go b/gopls/internal/protocol/uri.go index 5d00009b30d..361bc441cfe 100644 --- a/gopls/internal/protocol/uri.go +++ b/gopls/internal/protocol/uri.go @@ -110,7 +110,7 @@ func (uri DocumentURI) Encloses(file DocumentURI) bool { return pathutil.InDir(uri.Path(), file.Path()) } -// Locationr returns the Location for the specified range of this URI's file. +// Location returns the Location for the specified range of this URI's file. 
func (uri DocumentURI) Location(rng Range) Location { return Location{URI: uri, Range: rng} } From f12067dbd3714639176c64d2be0127493fced258 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Jun 2025 16:18:00 -0400 Subject: [PATCH 175/196] gopls/internal/analysis/modernize: disable appendclipped This analyzer is unsound w.r.t. nilness of the result. Sadly the problem is impractical to fix. This CL disables the analyzer for now. Fixes golang/go#73557 Change-Id: Ide38a02747439cca7383604dce3c81efa3f69b25 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678117 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/doc/analyzers.md | 3 --- gopls/internal/analysis/modernize/doc.go | 3 --- .../analysis/modernize/modernize_test.go | 1 + gopls/internal/analysis/modernize/slices.go | 24 ++++++++++++++----- gopls/internal/doc/api.json | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index c2bb5a6ad4f..ca793a4a885 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -3845,9 +3845,6 @@ Categories of modernize diagnostic: - efaceany: replace interface{} by the 'any' type added in go1.18. - - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or - slices.Concat(s), added in go1.21. - - mapsloop: replace a loop around an m[k]=v map update by a call to one of the Collect, Copy, Clone, or Insert functions from the maps package, added in go1.21. diff --git a/gopls/internal/analysis/modernize/doc.go b/gopls/internal/analysis/modernize/doc.go index e136807089f..62a8c6df309 100644 --- a/gopls/internal/analysis/modernize/doc.go +++ b/gopls/internal/analysis/modernize/doc.go @@ -63,9 +63,6 @@ // // - efaceany: replace interface{} by the 'any' type added in go1.18. // -// - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or -// slices.Concat(s), added in go1.21.
-// // - mapsloop: replace a loop around an m[k]=v map update by a call // to one of the Collect, Copy, Clone, or Insert functions from // the maps package, added in go1.21. diff --git a/gopls/internal/analysis/modernize/modernize_test.go b/gopls/internal/analysis/modernize/modernize_test.go index 7ef77f16bce..833ff898e35 100644 --- a/gopls/internal/analysis/modernize/modernize_test.go +++ b/gopls/internal/analysis/modernize/modernize_test.go @@ -13,6 +13,7 @@ import ( func Test(t *testing.T) { modernize.EnableSlicesDelete = true + modernize.EnableAppendClipped = true analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), modernize.Analyzer, "appendclipped", diff --git a/gopls/internal/analysis/modernize/slices.go b/gopls/internal/analysis/modernize/slices.go index f7461a043b5..1cefb2df51a 100644 --- a/gopls/internal/analysis/modernize/slices.go +++ b/gopls/internal/analysis/modernize/slices.go @@ -4,9 +4,6 @@ package modernize -// This file defines modernizers that use the "slices" package. -// TODO(adonovan): actually let's split them up and rename this file. - import ( "fmt" "go/ast" @@ -21,6 +18,17 @@ import ( "golang.org/x/tools/internal/analysisinternal" ) +// append(clipped, ...) cannot be replaced by slices.Concat (etc) +// without more attention to preservation of nilness; see #73557. +// Until we either fix it or revise our safety goals, we disable this +// analyzer for now. +// +// Its former documentation in doc.go was: +// +// - appendclipped: replace append([]T(nil), s...) by +// slices.Clone(s) or slices.Concat(s), added in go1.21. +var EnableAppendClipped = false + // The appendclipped pass offers to simplify a tower of append calls: // // append(append(append(base, a...), b..., c...) @@ -46,6 +54,10 @@ import ( // The fix does not always preserve nilness the of base slice when the // addends (a, b, c) are all empty (see #73557). 
func appendclipped(pass *analysis.Pass) { + if !EnableAppendClipped { + return + } + // Skip the analyzer in packages where its // fixes would create an import cycle. if within(pass, "slices", "bytes", "runtime") { @@ -115,7 +127,7 @@ func appendclipped(pass *analysis.Pass) { pass.Report(analysis.Diagnostic{ Pos: call.Pos(), End: call.End(), - Category: "slicesclone", + Category: "appendclipped", Message: "Redundant clone of os.Environ()", SuggestedFixes: []analysis.SuggestedFix{{ Message: "Eliminate redundant clone", @@ -155,7 +167,7 @@ func appendclipped(pass *analysis.Pass) { pass.Report(analysis.Diagnostic{ Pos: call.Pos(), End: call.End(), - Category: "slicesclone", + Category: "appendclipped", Message: message, SuggestedFixes: []analysis.SuggestedFix{{ Message: message, @@ -176,7 +188,7 @@ func appendclipped(pass *analysis.Pass) { pass.Report(analysis.Diagnostic{ Pos: call.Pos(), End: call.End(), - Category: "slicesclone", + Category: "appendclipped", Message: "Replace append with slices.Concat", SuggestedFixes: []analysis.SuggestedFix{{ Message: "Replace append with slices.Concat", diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index a77b64c473a..13b2dc91724 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -1498,7 +1498,7 @@ }, { "Name": "\"modernize\"", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with 
CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "Default": "true", "Status": "" }, @@ -3230,7 +3230,7 @@ }, { "Name": "modernize", - "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. 
It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - slicesclone: replace append([]T(nil), s...) by slices.Clone(s) or\n slices.Concat(s), added in go1.21.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) 
by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", + "Doc": "simplify code by using modern constructs\n\nThis analyzer reports opportunities for simplifying and clarifying\nexisting code by using more modern features of Go and its standard\nlibrary.\n\nEach diagnostic provides a fix. Our intent is that these fixes may\nbe safely applied en masse without changing the behavior of your\nprogram. In some cases the suggested fixes are imperfect and may\nlead to (for example) unused imports or unused local variables,\ncausing build breakage. However, these problems are generally\ntrivial to fix. 
We regard any modernizer whose fix changes program\nbehavior to have a serious bug and will endeavor to fix it.\n\nTo apply all modernization fixes en masse, you can use the\nfollowing command:\n\n\t$ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...\n\n(Do not use \"go get -tool\" to add gopls as a dependency of your\nmodule; gopls commands must be built from their release branch.)\n\nIf the tool warns of conflicting fixes, you may need to run it more\nthan once until it has applied all fixes cleanly. This command is\nnot an officially supported interface and may change in the future.\n\nChanges produced by this tool should be reviewed as usual before\nbeing merged. In some cases, a loop may be replaced by a simple\nfunction call, causing comments within the loop to be discarded.\nHuman judgment may be required to avoid losing comments of value.\n\nEach diagnostic reported by modernize has a specific category. (The\ncategories are listed below.) Diagnostics in some categories, such\nas \"efaceany\" (which replaces \"interface{}\" with \"any\" where it is\nsafe to do so) are particularly numerous. It may ease the burden of\ncode review to apply fixes in two passes, the first change\nconsisting only of fixes of category \"efaceany\", the second\nconsisting of all others. 
This can be achieved using the -category flag:\n\n\t$ modernize -category=efaceany -fix -test ./...\n\t$ modernize -category=-efaceany -fix -test ./...\n\nCategories of modernize diagnostic:\n\n - forvar: remove x := x variable declarations made unnecessary by the new semantics of loops in go1.22.\n\n - slicescontains: replace 'for i, elem := range s { if elem == needle { ...; break }'\n by a call to slices.Contains, added in go1.21.\n\n - minmax: replace an if/else conditional assignment by a call to\n the built-in min or max functions added in go1.21.\n\n - sortslice: replace sort.Slice(x, func(i, j int) bool) { return s[i] \u003c s[j] }\n by a call to slices.Sort(s), added in go1.21.\n\n - efaceany: replace interface{} by the 'any' type added in go1.18.\n\n - mapsloop: replace a loop around an m[k]=v map update by a call\n to one of the Collect, Copy, Clone, or Insert functions from\n the maps package, added in go1.21.\n\n - fmtappendf: replace []byte(fmt.Sprintf...) by fmt.Appendf(nil, ...),\n added in go1.19.\n\n - testingcontext: replace uses of context.WithCancel in tests\n with t.Context, added in go1.24.\n\n - omitzero: replace omitempty by omitzero on structs, added in go1.24.\n\n - bloop: replace \"for i := range b.N\" or \"for range b.N\" in a\n benchmark with \"for b.Loop()\", and remove any preceding calls\n to b.StopTimer, b.StartTimer, and b.ResetTimer.\n\n B.Loop intentionally defeats compiler optimizations such as\n inlining so that the benchmark is not entirely optimized away.\n Currently, however, it may cause benchmarks to become slower\n in some cases due to increased allocation; see\n https://go.dev/issue/73137.\n\n - rangeint: replace a 3-clause \"for i := 0; i \u003c n; i++\" loop by\n \"for i := range n\", added in go1.22.\n\n - stringsseq: replace Split in \"for range strings.Split(...)\" by go1.24's\n more efficient SplitSeq, or Fields with FieldSeq.\n\n - stringscutprefix: replace some uses of HasPrefix followed by TrimPrefix with 
CutPrefix,\n added to the strings package in go1.20.\n\n - waitgroup: replace old complex usages of sync.WaitGroup by less complex WaitGroup.Go method in go1.25.", "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/modernize", "Default": true }, From ec7b2b34cbdece4f23b7838b0d014bcff468252e Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Wed, 28 May 2025 10:44:30 -0400 Subject: [PATCH 176/196] gopls/internal/mcp: add exported type spec and value spec to context For imported packages, only add type spec and value spec whose ident is exported. When adding type spec and value spec, include doc comments, line comments and all the floating comments within the spec. Preserve the indentation by inserting the content between the start position and the preceding new line. For golang/go#73580 Change-Id: Ib927cdbf17649f844f3b725b19edf055f29a471d Reviewed-on: https://go-review.googlesource.com/c/tools/+/676796 Reviewed-by: Alan Donovan Auto-Submit: Hongxiang Jiang LUCI-TryBot-Result: Go LUCI --- gopls/internal/cache/parsego/file.go | 21 ++ gopls/internal/golang/extract.go | 18 +- gopls/internal/golang/undeclared.go | 2 +- gopls/internal/mcp/context.go | 116 ++++++++++- .../test/marker/testdata/mcptools/context.txt | 183 ++++++++++++++++++ 5 files changed, 317 insertions(+), 23 deletions(-) diff --git a/gopls/internal/cache/parsego/file.go b/gopls/internal/cache/parsego/file.go index 7254e1f4621..68ea0d0e4c2 100644 --- a/gopls/internal/cache/parsego/file.go +++ b/gopls/internal/cache/parsego/file.go @@ -10,6 +10,7 @@ import ( "go/scanner" "go/token" "sync" + "unicode" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/protocol" @@ -163,3 +164,23 @@ func (pgf *File) Resolve() { resolveFile(pgf.File, pgf.Tok, declErr) }) } + +// Indentation returns the string of spaces representing the indentation +// of the line containing the specified position.
+// This can be used to ensure that inserted code maintains consistent indentation +// and column alignment. +func (pgf *File) Indentation(pos token.Pos) (string, error) { + line := safetoken.Line(pgf.Tok, pos) + start, end, err := safetoken.Offsets(pgf.Tok, pgf.Tok.LineStart(line), pos) + if err != nil { + return "", err + } + + s := string(pgf.Src[start:end]) + for i, r := range s { + if !unicode.IsSpace(r) { + return s[:i], nil // prefix of spaces + } + } + return s, nil +} diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 37c3352a68f..91bea65a1f2 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -54,7 +54,6 @@ func extractVariable(pkg *cache.Package, pgf *parsego.File, start, end token.Pos file = pgf.File ) // TODO(adonovan): simplify, using Cursor. - tokFile := fset.File(file.FileStart) exprs, err := canExtractVariable(info, pgf.Cursor, start, end, all) if err != nil { return nil, nil, fmt.Errorf("cannot extract: %v", err) @@ -164,7 +163,7 @@ Outer: return nil, nil, fmt.Errorf("cannot find location to insert extraction: %v", err) } // Within function: compute appropriate statement indentation. - indent, err := calculateIndentation(pgf.Src, tokFile, before) + indent, err := pgf.Indentation(before.Pos()) if err != nil { return nil, nil, err } @@ -506,19 +505,6 @@ func canExtractVariable(info *types.Info, curFile inspector.Cursor, start, end t return exprs, nil } -// Calculate indentation for insertion. -// When inserting lines of code, we must ensure that the lines have consistent -// formatting (i.e. the proper indentation). To do so, we observe the indentation on the -// line of code on which the insertion occurs. 
-func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) (string, error) { - line := safetoken.Line(tok, insertBeforeStmt.Pos()) - lineOffset, stmtOffset, err := safetoken.Offsets(tok, tok.LineStart(line), insertBeforeStmt.Pos()) - if err != nil { - return "", err - } - return string(content[lineOffset:stmtOffset]), nil -} - // freshName returns an identifier based on prefix (perhaps with a // numeric suffix) that is not in scope at the specified position // within the file. It returns the next numeric suffix to use. @@ -1193,7 +1179,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to } before := src[outerStart:startOffset] after := src[endOffset:outerEnd] - indent, err := calculateIndentation(src, tok, node) + indent, err := pgf.Indentation(node.Pos()) if err != nil { return nil, nil, err } diff --git a/gopls/internal/golang/undeclared.go b/gopls/internal/golang/undeclared.go index 515da9bd891..63cf06e6943 100644 --- a/gopls/internal/golang/undeclared.go +++ b/gopls/internal/golang/undeclared.go @@ -137,7 +137,7 @@ func createUndeclared(pkg *cache.Package, pgf *parsego.File, start, end token.Po if err != nil { return nil, nil, fmt.Errorf("could not locate insertion point: %v", err) } - indent, err := calculateIndentation(pgf.Src, fset.File(file.FileStart), insertBeforeStmt) + indent, err := pgf.Indentation(insertBeforeStmt.Pos()) if err != nil { return nil, nil, err } diff --git a/gopls/internal/mcp/context.go b/gopls/internal/mcp/context.go index 9cac739c7a3..f1849db65f1 100644 --- a/gopls/internal/mcp/context.go +++ b/gopls/internal/mcp/context.go @@ -4,8 +4,8 @@ package mcp -// This file defines the "context" operation, which -// returns a summary of the specified package. +// This file defines the "context" operation, which returns a summary of the +// specified package. 
import ( "bytes" @@ -14,6 +14,7 @@ import ( "go/ast" "go/token" "path/filepath" + "slices" "strings" "golang.org/x/tools/gopls/internal/cache" @@ -165,11 +166,10 @@ func writePackageSummary(ctx context.Context, snapshot *cache.Snapshot, pkg *cac } out.Write(bytes.TrimSpace(text)) - out.WriteString("\n") + out.WriteString("\n\n") } // Write exported func decl and gen decl. - // TODO(hxjiang): write exported gen decl. for _, decl := range pgf.File.Decls { switch decl := decl.(type) { case *ast.FuncDecl: @@ -184,7 +184,6 @@ func writePackageSummary(ctx context.Context, snapshot *cache.Snapshot, pkg *cac } } - out.WriteString("\n") // Write doc comment and func signature. startPos := decl.Pos() if decl.Doc != nil { @@ -197,7 +196,112 @@ func writePackageSummary(ctx context.Context, snapshot *cache.Snapshot, pkg *cac } out.Write(text) - out.WriteString("\n") + out.WriteString("\n\n") + + case *ast.GenDecl: + if decl.Tok == token.IMPORT { + continue + } + + var buf bytes.Buffer + if decl.Doc != nil { + text, err := pgf.NodeText(decl.Doc) + if err != nil { + return err + } + buf.Write(text) + buf.WriteString("\n") + } + + buf.WriteString(decl.Tok.String() + " ") + if decl.Lparen.IsValid() { + buf.WriteString("(\n") + } + + var anyExported bool + for _, spec := range decl.Specs { + // Captures the full byte range of the spec, including + // its associated doc comments and line comments. + // This range also covers any floating comments as these + // can be valuable for context. Like + // ``` + // type foo struct { // floating comment. + // // floating comment. + // + // x int + // } + // ``` + var startPos, endPos token.Pos + + switch spec := spec.(type) { + case *ast.TypeSpec: + // TODO(hxjiang): only keep the exported field of + // struct spec and exported method of interface spec. + if !spec.Name.IsExported() { + continue + } + anyExported = true + + // Include preceding doc comment, if any. 
+ if spec.Doc == nil { + startPos = spec.Pos() + } else { + startPos = spec.Doc.Pos() + } + + // Include trailing line comment, if any. + if spec.Comment == nil { + endPos = spec.End() + } else { + endPos = spec.Comment.End() + } + + case *ast.ValueSpec: + // TODO(hxjiang): only keep the exported identifier. + if !slices.ContainsFunc(spec.Names, (*ast.Ident).IsExported) { + continue + } + anyExported = true + + if spec.Doc == nil { + startPos = spec.Pos() + } else { + startPos = spec.Doc.Pos() + } + + if spec.Comment == nil { + endPos = spec.End() + } else { + endPos = spec.Comment.End() + } + } + + indent, err := pgf.Indentation(startPos) + if err != nil { + return err + } + + buf.WriteString(indent) + + text, err := pgf.PosText(startPos, endPos) + if err != nil { + return err + } + + buf.Write(text) + buf.WriteString("\n") + } + + if decl.Lparen.IsValid() { + buf.WriteString(")\n") + } + + // Only write the summary of the genDecl if there is + // any exported spec. + if anyExported { + out.Write(buf.Bytes()) + out.WriteString("\n") + } } } diff --git a/gopls/internal/test/marker/testdata/mcptools/context.txt b/gopls/internal/test/marker/testdata/mcptools/context.txt index a6dba65f8d4..0cd63086130 100644 --- a/gopls/internal/test/marker/testdata/mcptools/context.txt +++ b/gopls/internal/test/marker/testdata/mcptools/context.txt @@ -69,6 +69,7 @@ doc.go: Package doc for package comment. */ package comment + <--- foo.go: @@ -86,6 +87,7 @@ import ( // Func doc for comment.Foo func Foo(foo string, _ int) + <--- -- function/foo.go -- @@ -139,6 +141,187 @@ package function func Foo(int, string) +type Exported struct{} + func (*Exported) Exported(int) + +<--- + +-- type.go -- +package main + +import( + "example.com/types" +) + +var x types.Exported //@loc(types, "x") + +//@mcptool("context", `{}`, types, output=withType) + +-- types/types.go -- +package types + +// Doc for exported. +type Exported struct { + // Doc for exported. + Exported string + // Doc for unexported. 
+ unexported string +} + +// Doc for types. +type ( + // Doc for Foo first line. + // Doc for Foo second line. + Foo struct { + foo string + } + + // Doc for foo. + foo struct {} + + // Doc for Bar. + Bar struct { + bar string + } + + // Doc for bar. + bar struct {} +) + +-- @withType -- +Code blocks are delimited by --->...<--- markers. + +Current file "type.go" contains this import declaration: +---> +import( + "example.com/types" +) +<--- + +The imported packages declare the following symbols: + +example.com/types (package types) +types.go: +---> +package types + +// Doc for exported. +type Exported struct { + // Doc for exported. + Exported string + // Doc for unexported. + unexported string +} + +// Doc for types. +type ( + // Doc for Foo first line. + // Doc for Foo second line. + Foo struct { + foo string + } + // Doc for Bar. + Bar struct { + bar string + } +) + +<--- + +-- value.go -- +package main + +import( + "example.com/values" +) + +var y values.ConstFoo //@loc(values, "y") + +//@mcptool("context", `{}`, values, output=withValue) + +-- values/consts.go -- +package values + +const ( + // doc for ConstFoo + ConstFoo = "Foo" // comment for ConstFoo + // doc for constFoo + constFoo = "foo" // comment for constFoo + // doc for ConstBar + ConstBar = "Bar" // comment for ConstBar + // doc for constBar + constBar = "bar" // comment for constBar +) + +// doc for ConstExported +const ConstExported = "Exported" // comment for ConstExported + +// doc for constUnexported +var constUnexported = "unexported" // comment for constUnexported + +-- values/vars.go -- +package values + +var ( + // doc for VarFoo + VarFoo = "Foo" // comment for VarFoo + // doc for varFoo + varFoo = "foo" // comment for varFoo + // doc for VarBar + VarBar = "Bar" // comment for VarBar + // doc for varBar + varBar = "bar" // comment for varBar +) + +// doc for VarExported +var VarExported = "Exported" // comment for VarExported + +// doc for varUnexported +var varUnexported = "unexported" // 
comment for varUnexported + +-- @withValue -- +Code blocks are delimited by --->...<--- markers. + +Current file "value.go" contains this import declaration: +---> +import( + "example.com/values" +) +<--- + +The imported packages declare the following symbols: + +example.com/values (package values) +consts.go: +---> +package values + +const ( + // doc for ConstFoo + ConstFoo = "Foo" // comment for ConstFoo + // doc for ConstBar + ConstBar = "Bar" // comment for ConstBar +) + +// doc for ConstExported +const ConstExported = "Exported" // comment for ConstExported + +<--- + +vars.go: +---> +package values + +var ( + // doc for VarFoo + VarFoo = "Foo" // comment for VarFoo + // doc for VarBar + VarBar = "Bar" // comment for VarBar +) + +// doc for VarExported +var VarExported = "Exported" // comment for VarExported + <--- From 58e5e62336b61c4b3cd7caecc3fe4e1feaa9ce27 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Mon, 2 Jun 2025 16:03:02 -0400 Subject: [PATCH 177/196] gopls/internal/test/marker: organize mcp tool context test To avoid confusion, different kind contest test have its own package, and the test is always triggered from testkind/main.go. For golang/go#73580 Change-Id: I6197d30ed4c8b4813f44efdbc45bb46162116006 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678155 Auto-Submit: Hongxiang Jiang Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- .../test/marker/testdata/mcptools/context.txt | 132 +++++++++--------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/gopls/internal/test/marker/testdata/mcptools/context.txt b/gopls/internal/test/marker/testdata/mcptools/context.txt index 0cd63086130..31f75ac732c 100644 --- a/gopls/internal/test/marker/testdata/mcptools/context.txt +++ b/gopls/internal/test/marker/testdata/mcptools/context.txt @@ -7,7 +7,21 @@ This test exercises mcp tool context. 
-- go.mod -- module example.com --- comment/doc.go -- +-- commenttest/main.go -- +package main + +import( + "example.com/commenttest/comment" +) + +func testComment() { //@loc(comment, "test") + comment.Foo("", 0) + function.Foo(0, "") +} + +//@mcptool("context", `{}`, comment, output=withComment) + +-- commenttest/comment/doc.go -- // Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -17,17 +31,17 @@ Package doc for package comment. */ package comment --- comment/foo.go -- -// File doc for foo.go part 1. +-- commenttest/comment/comment.go -- +// File doc for comment.go part 1. package comment -// File doc for foo.go part 2. +// File doc for comment.go part 2. import ( // comment for package renaming. myfmt "fmt" ) -// File doc for foo.go part 3. +// File doc for comment.go part 3. // Func doc for comment.Foo func Foo(foo string, _ int) { @@ -36,61 +50,60 @@ func Foo(foo string, _ int) { // Random comment floating around. --- comment.go -- -package main - -import( - "example.com/comment" -) - -func testComment() { //@loc(comment, "test") - comment.Foo("", 0) - function.Foo(0, "") -} - -//@mcptool("context", `{}`, comment, output=withComment) - -- @withComment -- Code blocks are delimited by --->...<--- markers. -Current file "comment.go" contains this import declaration: +Current file "main.go" contains this import declaration: ---> import( - "example.com/comment" + "example.com/commenttest/comment" ) <--- The imported packages declare the following symbols: -example.com/comment (package comment) -doc.go: ----> -/* -Package doc for package comment. -*/ -package comment - -<--- - -foo.go: +example.com/commenttest/comment (package comment) +comment.go: ---> -// File doc for foo.go part 1. +// File doc for comment.go part 1. package comment -// File doc for foo.go part 2. +// File doc for comment.go part 2. import ( // comment for package renaming. 
myfmt "fmt" ) -// File doc for foo.go part 3. +// File doc for comment.go part 3. // Func doc for comment.Foo func Foo(foo string, _ int) <--- --- function/foo.go -- +doc.go: +---> +/* +Package doc for package comment. +*/ +package comment + +<--- + +-- functiontest/main.go -- +package main + +import( + "example.com/functiontest/function" +) + +func testFunction() { //@loc(function, "test") + function.Foo(0, "") +} + +//@mcptool("context", `{}`, function, output=withFunction) + +-- functiontest/function/function.go -- package function func Foo(int, string) {} @@ -109,33 +122,20 @@ func (*Exported) unexported(int) {} func (*Exported) Exported(int) {} --- function.go -- -package main - -import( - "example.com/function" -) - -func testFunction() { //@loc(function, "test") - function.Foo(0, "") -} - -//@mcptool("context", `{}`, function, output=withFunction) - -- @withFunction -- Code blocks are delimited by --->...<--- markers. -Current file "function.go" contains this import declaration: +Current file "main.go" contains this import declaration: ---> import( - "example.com/function" + "example.com/functiontest/function" ) <--- The imported packages declare the following symbols: -example.com/function (package function) -foo.go: +example.com/functiontest/function (package function) +function.go: ---> package function @@ -147,18 +147,18 @@ func (*Exported) Exported(int) <--- --- type.go -- +-- typetest/main.go -- package main import( - "example.com/types" + "example.com/typetest/types" ) var x types.Exported //@loc(types, "x") //@mcptool("context", `{}`, types, output=withType) --- types/types.go -- +-- typetest/types/types.go -- package types // Doc for exported. @@ -192,16 +192,16 @@ type ( -- @withType -- Code blocks are delimited by --->...<--- markers. 
-Current file "type.go" contains this import declaration: +Current file "main.go" contains this import declaration: ---> import( - "example.com/types" + "example.com/typetest/types" ) <--- The imported packages declare the following symbols: -example.com/types (package types) +example.com/typetest/types (package types) types.go: ---> package types @@ -229,18 +229,18 @@ type ( <--- --- value.go -- +-- valuetest/main.go -- package main import( - "example.com/values" + "example.com/valuetest/values" ) var y values.ConstFoo //@loc(values, "y") //@mcptool("context", `{}`, values, output=withValue) --- values/consts.go -- +-- valuetest/values/consts.go -- package values const ( @@ -260,7 +260,7 @@ const ConstExported = "Exported" // comment for ConstExported // doc for constUnexported var constUnexported = "unexported" // comment for constUnexported --- values/vars.go -- +-- valuetest/values/vars.go -- package values var ( @@ -283,16 +283,16 @@ var varUnexported = "unexported" // comment for varUnexported -- @withValue -- Code blocks are delimited by --->...<--- markers. -Current file "value.go" contains this import declaration: +Current file "main.go" contains this import declaration: ---> import( - "example.com/values" + "example.com/valuetest/values" ) <--- The imported packages declare the following symbols: -example.com/values (package values) +example.com/valuetest/values (package values) consts.go: ---> package values From effd83eb4d44f4443100debb7ede6619a4ac6528 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Mon, 2 Jun 2025 10:20:34 -0400 Subject: [PATCH 178/196] gopls/internal/golang: add type inlayhint for variable decl Skip type inlayhint if the type is explicitly written like: var foo string to avoid duplicated information. Rename gomod inlayhint.go to inlay_hint.go for consistency. 
For golang/go#73946 Change-Id: I94944d2fde685ee0b57ce9fdb38cee31bb1038b2 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677935 Auto-Submit: Hongxiang Jiang Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/golang/inlay_hint.go | 34 +++++++++++++++---- .../mod/{inlayhint.go => inlay_hint.go} | 0 .../marker/testdata/inlayhints/inlayhints.txt | 4 +++ .../marker/testdata/inlayhints/issue67142.txt | 4 +-- 4 files changed, 33 insertions(+), 9 deletions(-) rename gopls/internal/mod/{inlayhint.go => inlay_hint.go} (100%) diff --git a/gopls/internal/golang/inlay_hint.go b/gopls/internal/golang/inlay_hint.go index 589a809f933..10f5db02827 100644 --- a/gopls/internal/golang/inlay_hint.go +++ b/gopls/internal/golang/inlay_hint.go @@ -165,13 +165,33 @@ func funcTypeParams(info *types.Info, pgf *parsego.File, qual types.Qualifier, c } func assignVariableTypes(info *types.Info, pgf *parsego.File, qual types.Qualifier, cur inspector.Cursor, add func(protocol.InlayHint)) { - for curAssign := range cur.Preorder((*ast.AssignStmt)(nil)) { - stmt := curAssign.Node().(*ast.AssignStmt) - if stmt.Tok != token.DEFINE { - continue - } - for _, v := range stmt.Lhs { - variableType(info, pgf, qual, v, add) + for node := range cur.Preorder((*ast.AssignStmt)(nil), (*ast.ValueSpec)(nil)) { + switch n := node.Node().(type) { + case *ast.AssignStmt: + if n.Tok != token.DEFINE { + continue + } + for _, v := range n.Lhs { + variableType(info, pgf, qual, v, add) + } + case *ast.GenDecl: + if n.Tok != token.VAR { + continue + } + for _, v := range n.Specs { + spec := v.(*ast.ValueSpec) + // The type of the variable is written, skip showing type of this var. 
+ // ```go + // var foo string + // ``` + if spec.Type != nil { + continue + } + + for _, v := range spec.Names { + variableType(info, pgf, qual, v, add) + } + } } } } diff --git a/gopls/internal/mod/inlayhint.go b/gopls/internal/mod/inlay_hint.go similarity index 100% rename from gopls/internal/mod/inlayhint.go rename to gopls/internal/mod/inlay_hint.go diff --git a/gopls/internal/test/marker/testdata/inlayhints/inlayhints.txt b/gopls/internal/test/marker/testdata/inlayhints/inlayhints.txt index 0ea40f78bc2..a2021cc6103 100644 --- a/gopls/internal/test/marker/testdata/inlayhints/inlayhints.txt +++ b/gopls/internal/test/marker/testdata/inlayhints/inlayhints.txt @@ -363,6 +363,8 @@ func SumNumbers[K comparable, V Number](m map[K]V) V { package inlayHint //@inlayhints(vartypes) func assignTypes() { + var x string + var y = "" i, j := 0, len([]string{})-1 println(i, j) } @@ -385,6 +387,8 @@ func compositeLitType() { package inlayHint //@inlayhints(vartypes) func assignTypes() { + var x string + var y = "" i< int>, j< int> := 0, len([]string{})-1 println(i, j) } diff --git a/gopls/internal/test/marker/testdata/inlayhints/issue67142.txt b/gopls/internal/test/marker/testdata/inlayhints/issue67142.txt index df25e6fb190..456da252377 100644 --- a/gopls/internal/test/marker/testdata/inlayhints/issue67142.txt +++ b/gopls/internal/test/marker/testdata/inlayhints/issue67142.txt @@ -25,11 +25,11 @@ go 1.21.9 //@inlayhints(out) package p -var _ = rand.Float64() +var _ = rand.Float64() -- @out -- //@inlayhints(out) package p -var _ = rand.Float64() +var _ = rand.Float64() From c3cb1f1305e93dfa7f6ce7ac4a809f22f93a765c Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Fri, 30 May 2025 14:35:14 -0400 Subject: [PATCH 179/196] gopls/internal/mcp: add top level symbols from current package When adding context against the current package based on the input file URI, write both the exported and unexported symbols. 
The summary of the current file is written before those of the other files in the current package. For golang/go#73580 Change-Id: Ica2166b8b005badcbf5b81e2714e5d326d29e7e0 Reviewed-on: https://go-review.googlesource.com/c/tools/+/677537 Auto-Submit: Hongxiang Jiang LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/mcp/context.go | 440 ++++++++++-------- .../test/marker/testdata/mcptools/context.txt | 234 ++++++++-- 2 files changed, 441 insertions(+), 233 deletions(-) diff --git a/gopls/internal/mcp/context.go b/gopls/internal/mcp/context.go index f1849db65f1..06911915c19 100644 --- a/gopls/internal/mcp/context.go +++ b/gopls/internal/mcp/context.go @@ -24,7 +24,6 @@ import ( "golang.org/x/tools/gopls/internal/golang" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/astutil" - "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/internal/mcp" ) @@ -53,10 +52,90 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal var result strings.Builder result.WriteString("Code blocks are delimited by --->...<--- markers.\n\n") - // TODO(hxjiang): consider making the context tool best effort. Ignore - // non-critical errors. - if err := writePackageSummary(ctx, snapshot, pkg, pgf, &result); err != nil { - return nil, err + + // TODO(hxjiang): add context based on location's range. + + fmt.Fprintf(&result, "Current package %q (package %s) declares the following symbols:\n\n", pkg.Metadata().PkgPath, pkg.Metadata().Name) + // Write context of the current file. + { + fmt.Fprintf(&result, "%s (current file):\n", filepath.Base(pgf.URI.Path())) + result.WriteString("--->\n") + if err := writeFileSummary(ctx, snapshot, pgf.URI, &result, false); err != nil { + return nil, err + } + result.WriteString("<---\n\n") + } + + // Write context of the rest of the files in the current package.
+ { + for _, file := range pkg.CompiledGoFiles() { + if file.URI == pgf.URI { + continue + } + + fmt.Fprintf(&result, "%s:\n", filepath.Base(file.URI.Path())) + result.WriteString("--->\n") + if err := writeFileSummary(ctx, snapshot, file.URI, &result, false); err != nil { + return nil, err + } + result.WriteString("<---\n\n") + } + } + + // Write dependencies context of current file. + if len(pgf.File.Imports) > 0 { + // Write import decls of the current file. + { + fmt.Fprintf(&result, "Current file %q contains this import declaration:\n", filepath.Base(pgf.URI.Path())) + result.WriteString("--->\n") + // Add all import decl to output including all floating comment by + // using GenDecl's start and end position. + for _, decl := range pgf.File.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.IMPORT { + continue + } + + text, err := pgf.NodeText(genDecl) + if err != nil { + return nil, err + } + + result.Write(text) + result.WriteString("\n") + } + result.WriteString("<---\n\n") + } + + // Write summaries from imported packages. 
+ { + result.WriteString("The imported packages declare the following symbols:\n\n") + for _, imp := range pgf.File.Imports { + importPath := metadata.UnquoteImportPath(imp) + if importPath == "" { + continue + } + + impID := pkg.Metadata().DepsByImpPath[importPath] + if impID == "" { + continue // ignore error + } + impMetadata := snapshot.Metadata(impID) + if impMetadata == nil { + continue // ignore error + } + + fmt.Fprintf(&result, "%q (package %s)\n", importPath, impMetadata.Name) + for _, f := range impMetadata.CompiledGoFiles { + fmt.Fprintf(&result, "%s:\n", filepath.Base(f.Path())) + result.WriteString("--->\n") + if err := writeFileSummary(ctx, snapshot, f, &result, true); err != nil { + return nil, err + } + result.WriteString("<---\n\n") + } + } + } } return &mcp.CallToolResult{ @@ -66,246 +145,213 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal }, nil } -// writePackageSummary writes the package summaries to the bytes buffer based on -// the input import specs. -func writePackageSummary(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, out *strings.Builder) error { - if len(pgf.File.Imports) == 0 { - return nil +// writeFileSummary writes the file summary to the string builder based on +// the input file URI. +func writeFileSummary(ctx context.Context, snapshot *cache.Snapshot, f protocol.DocumentURI, out *strings.Builder, onlyExported bool) error { + fh, err := snapshot.ReadFile(ctx, f) + if err != nil { + return err + } + pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) + if err != nil { + return err } - fmt.Fprintf(out, "Current file %q contains this import declaration:\n", filepath.Base(pgf.URI.Path())) - out.WriteString("--->\n") - // Add all import decl to output including all floating comment by using - // GenDecl's start and end position. 
- for _, decl := range pgf.File.Decls { - genDecl, ok := decl.(*ast.GenDecl) - if !ok { - continue + // Copy everything before the first non-import declaration: + // package decl, imports decl(s), and all comments (excluding copyright). + { + endPos := pgf.File.FileEnd + + outerloop: + for _, decl := range pgf.File.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if decl.Doc != nil { + endPos = decl.Doc.Pos() + } else { + endPos = decl.Pos() + } + break outerloop + case *ast.GenDecl: + if decl.Tok == token.IMPORT { + continue + } + if decl.Doc != nil { + endPos = decl.Doc.Pos() + } else { + endPos = decl.Pos() + } + break outerloop + } } - if genDecl.Tok != token.IMPORT { - continue + startPos := pgf.File.FileStart + if copyright := golang.CopyrightComment(pgf.File); copyright != nil { + startPos = copyright.End() } - text, err := pgf.NodeText(genDecl) + text, err := pgf.PosText(startPos, endPos) if err != nil { return err } - out.Write(text) - out.WriteString("\n") + out.Write(bytes.TrimSpace(text)) + out.WriteString("\n\n") } - out.WriteString("<---\n\n") - out.WriteString("The imported packages declare the following symbols:\n\n") + // Write func decl and gen decl. + for _, decl := range pgf.File.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if onlyExported { + if !decl.Name.IsExported() { + continue + } - for _, imp := range pgf.File.Imports { - importPath := metadata.UnquoteImportPath(imp) - if importPath == "" { - continue - } + if decl.Recv != nil && len(decl.Recv.List) > 0 { + _, rname, _ := astutil.UnpackRecv(decl.Recv.List[0].Type) + if !rname.IsExported() { + continue + } + } + } - impID := pkg.Metadata().DepsByImpPath[importPath] - if impID == "" { - return fmt.Errorf("no package data for import %q", importPath) - } - impMetadata := snapshot.Metadata(impID) - if impMetadata == nil { - return bug.Errorf("failed to resolve import ID %q", impID) - } + // Write doc comment and func signature. 
+ startPos := decl.Pos() + if decl.Doc != nil { + startPos = decl.Doc.Pos() + } - fmt.Fprintf(out, "%s (package %s)\n", importPath, impMetadata.Name) - for _, f := range impMetadata.CompiledGoFiles { - fmt.Fprintf(out, "%s:\n", filepath.Base(f.Path())) - out.WriteString("--->\n") - fh, err := snapshot.ReadFile(ctx, f) + text, err := pgf.PosText(startPos, decl.Type.End()) if err != nil { return err } - pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) - if err != nil { - return err + + out.Write(text) + out.WriteString("\n\n") + + case *ast.GenDecl: + if decl.Tok == token.IMPORT { + continue } - // Copy everything before the first non-import declaration: - // package decl, imports decl(s), and all comments (excluding copyright). - { - endPos := pgf.File.FileEnd - - outerloop: - for _, decl := range pgf.File.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - if decl.Doc != nil { - endPos = decl.Doc.Pos() - } else { - endPos = decl.Pos() - } - break outerloop - case *ast.GenDecl: - if decl.Tok == token.IMPORT { - continue - } - if decl.Doc != nil { - endPos = decl.Doc.Pos() - } else { - endPos = decl.Pos() - } - break outerloop - } + // Dump the entire GenDecl (exported or unexported) + // including doc comment without any filtering to the output. + if !onlyExported { + startPos := decl.Pos() + if decl.Doc != nil { + startPos = decl.Doc.Pos() } - - startPos := pgf.File.FileStart - if copyright := golang.CopyrightComment(pgf.File); copyright != nil { - startPos = copyright.End() + text, err := pgf.PosText(startPos, decl.End()) + if err != nil { + return err } - text, err := pgf.PosText(startPos, endPos) + out.Write(text) + out.WriteString("\n") + continue + } + + // Write only the GenDecl with exported identifier to the output. 
+ var buf bytes.Buffer + if decl.Doc != nil { + text, err := pgf.NodeText(decl.Doc) if err != nil { return err } + buf.Write(text) + buf.WriteString("\n") + } - out.Write(bytes.TrimSpace(text)) - out.WriteString("\n\n") + buf.WriteString(decl.Tok.String() + " ") + if decl.Lparen.IsValid() { + buf.WriteString("(\n") } - // Write exported func decl and gen decl. - for _, decl := range pgf.File.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - if !decl.Name.IsExported() { + var anyExported bool + for _, spec := range decl.Specs { + // Captures the full byte range of the spec, including + // its associated doc comments and line comments. + // This range also covers any floating comments as these + // can be valuable for context. Like + // ``` + // type foo struct { // floating comment. + // // floating comment. + // + // x int + // } + // ``` + var startPos, endPos token.Pos + + switch spec := spec.(type) { + case *ast.TypeSpec: + // TODO(hxjiang): only keep the exported field of + // struct spec and exported method of interface spec. + if !spec.Name.IsExported() { continue } + anyExported = true - if decl.Recv != nil && len(decl.Recv.List) > 0 { - _, rname, _ := astutil.UnpackRecv(decl.Recv.List[0].Type) - if !rname.IsExported() { - continue - } + // Include preceding doc comment, if any. + if spec.Doc == nil { + startPos = spec.Pos() + } else { + startPos = spec.Doc.Pos() } - // Write doc comment and func signature. - startPos := decl.Pos() - if decl.Doc != nil { - startPos = decl.Doc.Pos() + // Include trailing line comment, if any. + if spec.Comment == nil { + endPos = spec.End() + } else { + endPos = spec.Comment.End() } - text, err := pgf.PosText(startPos, decl.Type.End()) - if err != nil { - return err - } - - out.Write(text) - out.WriteString("\n\n") - - case *ast.GenDecl: - if decl.Tok == token.IMPORT { + case *ast.ValueSpec: + // TODO(hxjiang): only keep the exported identifier. 
+ if !slices.ContainsFunc(spec.Names, (*ast.Ident).IsExported) { continue } + anyExported = true - var buf bytes.Buffer - if decl.Doc != nil { - text, err := pgf.NodeText(decl.Doc) - if err != nil { - return err - } - buf.Write(text) - buf.WriteString("\n") + if spec.Doc == nil { + startPos = spec.Pos() + } else { + startPos = spec.Doc.Pos() } - buf.WriteString(decl.Tok.String() + " ") - if decl.Lparen.IsValid() { - buf.WriteString("(\n") + if spec.Comment == nil { + endPos = spec.End() + } else { + endPos = spec.Comment.End() } + } - var anyExported bool - for _, spec := range decl.Specs { - // Captures the full byte range of the spec, including - // its associated doc comments and line comments. - // This range also covers any floating comments as these - // can be valuable for context. Like - // ``` - // type foo struct { // floating comment. - // // floating comment. - // - // x int - // } - // ``` - var startPos, endPos token.Pos - - switch spec := spec.(type) { - case *ast.TypeSpec: - // TODO(hxjiang): only keep the exported field of - // struct spec and exported method of interface spec. - if !spec.Name.IsExported() { - continue - } - anyExported = true - - // Include preceding doc comment, if any. - if spec.Doc == nil { - startPos = spec.Pos() - } else { - startPos = spec.Doc.Pos() - } - - // Include trailing line comment, if any. - if spec.Comment == nil { - endPos = spec.End() - } else { - endPos = spec.Comment.End() - } - - case *ast.ValueSpec: - // TODO(hxjiang): only keep the exported identifier. 
- if !slices.ContainsFunc(spec.Names, (*ast.Ident).IsExported) { - continue - } - anyExported = true - - if spec.Doc == nil { - startPos = spec.Pos() - } else { - startPos = spec.Doc.Pos() - } - - if spec.Comment == nil { - endPos = spec.End() - } else { - endPos = spec.Comment.End() - } - } - - indent, err := pgf.Indentation(startPos) - if err != nil { - return err - } - - buf.WriteString(indent) - - text, err := pgf.PosText(startPos, endPos) - if err != nil { - return err - } - - buf.Write(text) - buf.WriteString("\n") - } + indent, err := pgf.Indentation(startPos) + if err != nil { + return err + } - if decl.Lparen.IsValid() { - buf.WriteString(")\n") - } + buf.WriteString(indent) - // Only write the summary of the genDecl if there is - // any exported spec. - if anyExported { - out.Write(buf.Bytes()) - out.WriteString("\n") - } + text, err := pgf.PosText(startPos, endPos) + if err != nil { + return err } + + buf.Write(text) + buf.WriteString("\n") + } + + if decl.Lparen.IsValid() { + buf.WriteString(")\n") } - out.WriteString("<---\n\n") + // Only write the summary of the genDecl if there is + // any exported spec. + if anyExported { + out.Write(buf.Bytes()) + out.WriteString("\n") + } } } return nil diff --git a/gopls/internal/test/marker/testdata/mcptools/context.txt b/gopls/internal/test/marker/testdata/mcptools/context.txt index 31f75ac732c..493545952ba 100644 --- a/gopls/internal/test/marker/testdata/mcptools/context.txt +++ b/gopls/internal/test/marker/testdata/mcptools/context.txt @@ -7,21 +7,39 @@ This test exercises mcp tool context. -- go.mod -- module example.com --- commenttest/main.go -- +-- a/main.go -- +// File doc for main.go part 1. package main +// File doc for main.go part 2. import( - "example.com/commenttest/comment" + "example.com/a/comment" ) -func testComment() { //@loc(comment, "test") +// File doc for main.go part 3. + +// doc comment for func foo. 
+func foo() {//@mcptool("context", `{}`, "foo", output=withComment) comment.Foo("", 0) - function.Foo(0, "") } -//@mcptool("context", `{}`, comment, output=withComment) +-- a/a.go -- +// File doc for a.go. +package main + +// doc comment for func a. +func a () {} + +// doc comment for type b. +type b struct {} + +// doc comment for const c. +const c = "" + +// doc comment for var d. +var d int --- commenttest/comment/doc.go -- +-- a/comment/doc.go -- // Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -31,7 +49,7 @@ Package doc for package comment. */ package comment --- commenttest/comment/comment.go -- +-- a/comment/comment.go -- // File doc for comment.go part 1. package comment @@ -43,7 +61,7 @@ import ( // File doc for comment.go part 3. -// Func doc for comment.Foo +// doc comment for comment.Foo func Foo(foo string, _ int) { myfmt.Printf("%s", foo) } @@ -53,16 +71,51 @@ func Foo(foo string, _ int) { -- @withComment -- Code blocks are delimited by --->...<--- markers. +Current package "example.com/a" (package main) declares the following symbols: + +main.go (current file): +---> +// File doc for main.go part 1. +package main + +// File doc for main.go part 2. +import( + "example.com/a/comment" +) + +// File doc for main.go part 3. + +// doc comment for func foo. +func foo() + +<--- + +a.go: +---> +// File doc for a.go. +package main + +// doc comment for func a. +func a () + +// doc comment for type b. +type b struct {} +// doc comment for const c. +const c = "" +// doc comment for var d. +var d int +<--- + Current file "main.go" contains this import declaration: ---> import( - "example.com/commenttest/comment" + "example.com/a/comment" ) <--- The imported packages declare the following symbols: -example.com/commenttest/comment (package comment) +"example.com/a/comment" (package comment) comment.go: ---> // File doc for comment.go part 1. 
@@ -76,7 +129,7 @@ import ( // File doc for comment.go part 3. -// Func doc for comment.Foo +// doc comment for comment.Foo func Foo(foo string, _ int) <--- @@ -90,20 +143,18 @@ package comment <--- --- functiontest/main.go -- +-- b/main.go -- package main import( - "example.com/functiontest/function" + "example.com/b/function" ) -func testFunction() { //@loc(function, "test") +func testFunction() {//@mcptool("context", `{}`, "test", output=withFunction) function.Foo(0, "") } -//@mcptool("context", `{}`, function, output=withFunction) - --- functiontest/function/function.go -- +-- b/function/function.go -- package function func Foo(int, string) {} @@ -125,16 +176,30 @@ func (*Exported) Exported(int) {} -- @withFunction -- Code blocks are delimited by --->...<--- markers. +Current package "example.com/b" (package main) declares the following symbols: + +main.go (current file): +---> +package main + +import( + "example.com/b/function" +) + +func testFunction() + +<--- + Current file "main.go" contains this import declaration: ---> import( - "example.com/functiontest/function" + "example.com/b/function" ) <--- The imported packages declare the following symbols: -example.com/functiontest/function (package function) +"example.com/b/function" (package function) function.go: ---> package function @@ -147,18 +212,16 @@ func (*Exported) Exported(int) <--- --- typetest/main.go -- +-- c/main.go -- package main import( - "example.com/typetest/types" + "example.com/c/types" ) -var x types.Exported //@loc(types, "x") - -//@mcptool("context", `{}`, types, output=withType) +var x types.Exported //@mcptool("context", `{}`, "x", output=withType) --- typetest/types/types.go -- +-- c/types/types.go -- package types // Doc for exported. @@ -192,16 +255,29 @@ type ( -- @withType -- Code blocks are delimited by --->...<--- markers. 
+Current package "example.com/c" (package main) declares the following symbols: + +main.go (current file): +---> +package main + +import( + "example.com/c/types" +) + +var x types.Exported +<--- + Current file "main.go" contains this import declaration: ---> import( - "example.com/typetest/types" + "example.com/c/types" ) <--- The imported packages declare the following symbols: -example.com/typetest/types (package types) +"example.com/c/types" (package types) types.go: ---> package types @@ -229,18 +305,16 @@ type ( <--- --- valuetest/main.go -- +-- d/main.go -- package main import( - "example.com/valuetest/values" + "example.com/d/values" ) -var y values.ConstFoo //@loc(values, "y") - -//@mcptool("context", `{}`, values, output=withValue) +var y values.ConstFoo //@mcptool("context", `{}`, "y", output=withValue) --- valuetest/values/consts.go -- +-- d/values/consts.go -- package values const ( @@ -260,7 +334,7 @@ const ConstExported = "Exported" // comment for ConstExported // doc for constUnexported var constUnexported = "unexported" // comment for constUnexported --- valuetest/values/vars.go -- +-- d/values/vars.go -- package values var ( @@ -283,16 +357,29 @@ var varUnexported = "unexported" // comment for varUnexported -- @withValue -- Code blocks are delimited by --->...<--- markers. 
+Current package "example.com/d" (package main) declares the following symbols: + +main.go (current file): +---> +package main + +import( + "example.com/d/values" +) + +var y values.ConstFoo +<--- + Current file "main.go" contains this import declaration: ---> import( - "example.com/valuetest/values" + "example.com/d/values" ) <--- The imported packages declare the following symbols: -example.com/valuetest/values (package values) +"example.com/d/values" (package values) consts.go: ---> package values @@ -325,3 +412,78 @@ var VarExported = "Exported" // comment for VarExported <--- +-- e/main.go -- +package main + +func main() {} //@mcptool("context", `{}`, "main", output=samePackage) + +-- e/foo.go -- +package main + +var ( + foo string + Foo string +) + +-- e/bar.go -- +package main + +const ( + bar = "" + Bar = "" +) + +-- e/baz.go -- +package main + +func baz(int) string { + return "" +} + +func Baz(string) int { + return 0 +} + +-- @samePackage -- +Code blocks are delimited by --->...<--- markers. + +Current package "example.com/e" (package main) declares the following symbols: + +main.go (current file): +---> +package main + +func main() + +<--- + +bar.go: +---> +package main + +const ( + bar = "" + Bar = "" +) +<--- + +baz.go: +---> +package main + +func baz(int) string + +func Baz(string) int + +<--- + +foo.go: +---> +package main + +var ( + foo string + Foo string +) +<--- + From 25caa76c0f396c8d794d17f7401cb8b1deb88137 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 3 Jun 2025 10:23:08 -0400 Subject: [PATCH 180/196] gopls/internal/telemetry/cmd/stacks: delete It moved to x/telemetry/cmd/stacks in CL 678275. 
Change-Id: I8e91068dc8227d37b34720ff3f2e8311f899bee6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678295 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/internal/telemetry/cmd/stacks/stacks.go | 1379 ----------------- .../telemetry/cmd/stacks/stacks_test.go | 347 ----- 2 files changed, 1726 deletions(-) delete mode 100644 gopls/internal/telemetry/cmd/stacks/stacks.go delete mode 100644 gopls/internal/telemetry/cmd/stacks/stacks_test.go diff --git a/gopls/internal/telemetry/cmd/stacks/stacks.go b/gopls/internal/telemetry/cmd/stacks/stacks.go deleted file mode 100644 index 5c7625e3b9c..00000000000 --- a/gopls/internal/telemetry/cmd/stacks/stacks.go +++ /dev/null @@ -1,1379 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build linux || darwin - -// The stacks command finds all gopls stack traces reported by -// telemetry in the past 7 days, and reports their associated GitHub -// issue, creating new issues as needed. -// -// The association of stacks with GitHub issues (labelled -// gopls/telemetry-wins) is represented in two different ways by the -// body (first comment) of the issue: -// -// 1. Each distinct stack is identified by an ID, 6-digit base64 -// string such as "TwtkSg". If a stack's ID appears anywhere -// within the issue body, the stack is associated with the issue. -// -// Some problems are highly deterministic, resulting in many -// field reports of the exact same stack. For such problems, a -// single ID in the issue body suffices to record the -// association. But most problems are exhibited in a variety of -// ways, leading to multiple field reports of similar but -// distinct stacks. Hence the following way to associate stacks -// with issues. -// -// 2. 
Each GitHub issue body may start with a code block of this form: -// -// ``` -// #!stacks -// "runtime.sigpanic" && "golang.hover:+170" -// ``` -// -// The first line indicates the purpose of the block; the -// remainder is a predicate that matches stacks. -// It is an expression defined by this grammar: -// -// > expr = "string literal" -// > | ( expr ) -// > | ! expr -// > | expr && expr -// > | expr || expr -// -// Each string literal must match complete words on the stack; -// the other productions are boolean operations. -// As an example of literal matching, "fu+12" matches "x:fu+12 " -// but not "fu:123" or "snafu+12". -// -// The stacks command gathers all such predicates out of the -// labelled issues and evaluates each one against each new stack. -// If the predicate for an issue matches, the issue is considered -// to have "claimed" the stack: the stack command appends a -// comment containing the new (variant) stack to the issue, and -// appends the stack's ID to the last line of the issue body. -// -// It is an error if two issues' predicates attempt to claim the -// same stack. -package main - -// TODO(adonovan): create a proper package with tests. Much of this -// machinery might find wider use in other x/telemetry clients. 
- -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "flag" - "fmt" - "go/ast" - "go/parser" - "go/token" - "hash/fnv" - "io" - "log" - "net/http" - "net/url" - "os" - "os/exec" - "path" - "path/filepath" - "regexp" - "runtime" - "sort" - "strconv" - "strings" - "time" - "unicode" - - "golang.org/x/mod/semver" - "golang.org/x/sys/unix" - "golang.org/x/telemetry" - "golang.org/x/tools/gopls/internal/util/browser" - "golang.org/x/tools/gopls/internal/util/moremaps" - "golang.org/x/tools/gopls/internal/util/morestrings" -) - -// flags -var ( - programFlag = flag.String("program", "golang.org/x/tools/gopls", "Package path of program to process") - - daysFlag = flag.Int("days", 7, "number of previous days of telemetry data to read") - - dryRun = flag.Bool("n", false, "dry run, avoid updating issues") -) - -// ProgramConfig is the configuration for processing reports for a specific -// program. -type ProgramConfig struct { - // Program is the package path of the program to process. - Program string - - // IncludeClient indicates that stack Info should include gopls/client metadata. - IncludeClient bool - - // SearchLabel is the GitHub label used to find all existing reports. - SearchLabel string - - // NewIssuePrefix is the package prefix to apply to new issue titles. - NewIssuePrefix string - - // NewIssueLabels are the labels to apply to new issues. - NewIssueLabels []string - - // MatchSymbolPrefix is the prefix of "interesting" symbol names. - // - // A given stack will be "blamed" on the deepest symbol in the stack that: - // 1. Matches MatchSymbolPrefix - // 2. Is an exported function or any method on an exported Type. - // 3. Does _not_ match IgnoreSymbolContains. - MatchSymbolPrefix string - - // IgnoreSymbolContains are "uninteresting" symbol substrings. e.g., - // logging packages. 
- IgnoreSymbolContains []string - - // Repository is the repository where the issues should be created, for example: "golang/go" - Repository string -} - -var programs = map[string]ProgramConfig{ - "golang.org/x/tools/gopls": { - Program: "golang.org/x/tools/gopls", - IncludeClient: true, - SearchLabel: "gopls/telemetry-wins", - NewIssuePrefix: "x/tools/gopls", - NewIssueLabels: []string{ - "gopls", - "Tools", - "gopls/telemetry-wins", - "NeedsInvestigation", - }, - MatchSymbolPrefix: "golang.org/x/tools/gopls/", - IgnoreSymbolContains: []string{ - "internal/util/bug.", - "internal/bug.", // former name in gopls/0.14.2 - }, - Repository: "golang/go", - }, - "cmd/compile": { - Program: "cmd/compile", - SearchLabel: "compiler/telemetry-wins", - NewIssuePrefix: "cmd/compile", - NewIssueLabels: []string{ - "compiler/runtime", - "compiler/telemetry-wins", - "NeedsInvestigation", - }, - MatchSymbolPrefix: "cmd/compile", - IgnoreSymbolContains: []string{ - // Various "fatal" wrappers. - "Fatal", // base.Fatal*, ssa.Value.Fatal*, etc. - "cmd/compile/internal/base.Assert", - "cmd/compile/internal/noder.assert", - "cmd/compile/internal/ssa.Compile.func1", // basically a Fatalf wrapper. - // Panic recovery. 
- "cmd/compile/internal/types2.(*Checker).handleBailout", - "cmd/compile/internal/gc.handlePanic", - }, - Repository: "golang/go", - }, - "github.com/go-delve/delve/cmd/dlv": { - Program: "github.com/go-delve/delve/cmd/dlv", - IncludeClient: false, - SearchLabel: "delve/telemetry-wins", - NewIssuePrefix: "telemetry report", - NewIssueLabels: []string{ - "delve/telemetry-wins", - }, - MatchSymbolPrefix: "github.com/go-delve/delve", - IgnoreSymbolContains: []string{ - "service/dap.(*Session).recoverPanic", - "rpccommon.newInternalError", - "rpccommon.(*ServerImpl).serveJSONCodec", - }, - Repository: "go-delve/delve", - }, -} - -func main() { - log.SetFlags(0) - log.SetPrefix("stacks: ") - flag.Parse() - - var ghclient *githubClient - - // Read GitHub authentication token from $HOME/.stacks.token. - // - // You can create one using the flow at: GitHub > You > Settings > - // Developer Settings > Personal Access Tokens > Fine-grained tokens > - // Generate New Token. Generate the token on behalf of golang/go - // with R/W access to "Issues". - // The token is typically of the form "github_pat_XXX", with 82 hex digits. - // Save it in the file, with mode 0400. - // - // For security, secret tokens should be read from files, not - // command-line flags or environment variables. - { - home, err := os.UserHomeDir() - if err != nil { - log.Fatal(err) - } - tokenFile := filepath.Join(home, ".stacks.token") - content, err := os.ReadFile(tokenFile) - if err != nil { - log.Fatalf("cannot read GitHub authentication token: %v", err) - } - ghclient = &githubClient{authToken: string(bytes.TrimSpace(content))} - } - - pcfg, ok := programs[*programFlag] - if !ok { - log.Fatalf("unknown -program %s", *programFlag) - } - - // Read all recent telemetry reports. 
- stacks, distinctStacks, stackToURL, err := readReports(pcfg, *daysFlag) - if err != nil { - log.Fatalf("Error reading reports: %v", err) - } - - issues, err := readIssues(ghclient, pcfg) - if err != nil { - log.Fatalf("Error reading issues: %v", err) - } - - // Map stacks to existing issues (if any). - claimedBy := claimStacks(issues, stacks) - - // Update existing issues that claimed new stacks. - updateIssues(ghclient, pcfg.Repository, issues, stacks, stackToURL) - - // For each stack, show existing issue or create a new one. - // Aggregate stack IDs by issue summary. - var ( - // Both vars map the summary line to the stack count. - existingIssues = make(map[string]int64) - newIssues = make(map[string]int64) - ) - for stack, counts := range stacks { - id := stackID(stack) - - var total int64 - for _, count := range counts { - total += count - } - - if issue, ok := claimedBy[id]; ok { - // existing issue, already updated above, just store - // the summary. - state := issue.State - if issue.State == "closed" && issue.StateReason == "completed" { - state = "completed" - } - summary := fmt.Sprintf("#%d: %s [%s]", - issue.Number, issue.Title, state) - if state == "completed" && issue.Milestone != nil { - summary += " milestone " + strings.TrimPrefix(issue.Milestone.Title, "gopls/") - } - existingIssues[summary] += total - } else { - // new issue, need to create GitHub issue and store - // summary. - title := newIssue(pcfg, stack, id, stackToURL[stack], counts) - summary := fmt.Sprintf("%s: %s [%s]", id, title, "new") - newIssues[summary] += total - } - } - - fmt.Printf("Found %d distinct stacks in last %v days:\n", distinctStacks, *daysFlag) - print := func(caption string, issues map[string]int64) { - // Print items in descending frequency. 
- keys := moremaps.KeySlice(issues) - sort.Slice(keys, func(i, j int) bool { - return issues[keys[i]] > issues[keys[j]] - }) - fmt.Printf("%s issues:\n", caption) - for _, summary := range keys { - count := issues[summary] - // Show closed issues in "white". - if isTerminal(os.Stdout) && (strings.Contains(summary, "[closed]") || strings.Contains(summary, "[completed]")) { - // ESC + "[" + n + "m" => change color to n - // (37 = white, 0 = default) - summary = "\x1B[37m" + summary + "\x1B[0m" - } - fmt.Printf("%s (n=%d)\n", summary, count) - } - } - print("Existing", existingIssues) - print("New", newIssues) -} - -// Info is used as a key for de-duping and aggregating. -// Do not add detail about particular records (e.g. data, telemetry URL). -type Info struct { - Program string // "golang.org/x/tools/gopls" - ProgramVersion string // "v0.16.1" - GoVersion string // "go1.23" - GOOS, GOARCH string - GoplsClient string // e.g. "vscode" (only set if Program == "golang.org/x/tools/gopls") -} - -func (info Info) String() string { - s := fmt.Sprintf("%s@%s %s %s/%s", - info.Program, info.ProgramVersion, - info.GoVersion, info.GOOS, info.GOARCH) - if info.GoplsClient != "" { - s += " " + info.GoplsClient - } - return s -} - -// readReports downloads telemetry stack reports for a program from the -// specified number of most recent days. -// -// stacks is a map of stack text to program metadata to stack+metadata report -// count. -// distinctStacks is the number of distinct stacks across all reports. -// stackToURL maps the stack text to the oldest telemetry JSON report it was -// included in. 
-func readReports(pcfg ProgramConfig, days int) (stacks map[string]map[Info]int64, distinctStacks int, stackToURL map[string]string, err error) { - stacks = make(map[string]map[Info]int64) - stackToURL = make(map[string]string) - - t := time.Now() - for i := range days { - date := t.Add(-time.Duration(i+1) * 24 * time.Hour).Format(time.DateOnly) - - url := fmt.Sprintf("https://telemetry.go.dev/data/%s", date) - resp, err := http.Get(url) - if err != nil { - return nil, 0, nil, fmt.Errorf("error on GET %s: %v", url, err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - return nil, 0, nil, fmt.Errorf("GET %s returned %d %s", url, resp.StatusCode, resp.Status) - } - - dec := json.NewDecoder(resp.Body) - for { - var report telemetry.Report - if err := dec.Decode(&report); err != nil { - if err == io.EOF { - break - } - return nil, 0, nil, fmt.Errorf("error decoding report: %v", err) - } - for _, prog := range report.Programs { - if prog.Program != pcfg.Program { - continue - } - if len(prog.Stacks) == 0 { - continue - } - // Ignore @devel versions as they correspond to - // ephemeral (and often numerous) variations of - // the program as we work on a fix to a bug. - if prog.Version == "devel" { - continue - } - - // Include applicable client names (e.g. vscode, eglot) for gopls. 
- var clientSuffix string - if pcfg.IncludeClient { - var clients []string - for key := range prog.Counters { - if client, ok := strings.CutPrefix(key, "gopls/client:"); ok { - clients = append(clients, client) - } - } - sort.Strings(clients) - if len(clients) > 0 { - clientSuffix = strings.Join(clients, ",") - } - } - - info := Info{ - Program: prog.Program, - ProgramVersion: prog.Version, - GoVersion: prog.GoVersion, - GOOS: prog.GOOS, - GOARCH: prog.GOARCH, - GoplsClient: clientSuffix, - } - for stack, count := range prog.Stacks { - counts := stacks[stack] - if counts == nil { - counts = make(map[Info]int64) - stacks[stack] = counts - } - counts[info] += count - stackToURL[stack] = url - } - distinctStacks += len(prog.Stacks) - } - } - } - - return stacks, distinctStacks, stackToURL, nil -} - -// readIssues returns all existing issues for the given program and parses any -// predicates. -func readIssues(cli *githubClient, pcfg ProgramConfig) ([]*Issue, error) { - // Query GitHub for all existing GitHub issues with the report label. - issues, err := cli.searchIssues(pcfg.Repository, pcfg.SearchLabel) - if err != nil { - // TODO(jba): return error instead of dying, or doc. - log.Fatalf("GitHub issues label %q search failed: %v", pcfg.SearchLabel, err) - } - - // Extract and validate predicate expressions in ```#!stacks...``` code blocks. - // See the package doc comment for the grammar. - for _, issue := range issues { - block := findPredicateBlock(issue.Body) - if block != "" { - pred, err := parsePredicate(block) - if err != nil { - log.Printf("invalid predicate in issue #%d: %v\n<<%s>>", - issue.Number, err, block) - continue - } - issue.predicate = pred - } - } - - return issues, nil -} - -// parsePredicate parses a predicate expression, returning a function that evaluates -// the predicate on a stack. -// The expression must match this grammar: -// -// expr = "string literal" -// | ( expr ) -// | ! 
expr -// | expr && expr -// | expr || expr -// -// The value of a string literal is whether it is a substring of the stack, respecting word boundaries. -// That is, a literal L behaves like the regular expression \bL'\b, where L' is L with -// regexp metacharacters quoted. -func parsePredicate(s string) (func(string) bool, error) { - expr, err := parser.ParseExpr(s) - if err != nil { - return nil, fmt.Errorf("parse error: %w", err) - } - - // Cache compiled regexps since we need them more than once. - literalRegexps := make(map[*ast.BasicLit]*regexp.Regexp) - - // Check for errors in the predicate so we can report them now, - // ensuring that evaluation is error-free. - var validate func(ast.Expr) error - validate = func(e ast.Expr) error { - switch e := e.(type) { - case *ast.UnaryExpr: - if e.Op != token.NOT { - return fmt.Errorf("invalid op: %s", e.Op) - } - return validate(e.X) - - case *ast.BinaryExpr: - if e.Op != token.LAND && e.Op != token.LOR { - return fmt.Errorf("invalid op: %s", e.Op) - } - if err := validate(e.X); err != nil { - return err - } - return validate(e.Y) - - case *ast.ParenExpr: - return validate(e.X) - - case *ast.BasicLit: - if e.Kind != token.STRING { - return fmt.Errorf("invalid literal (%s)", e.Kind) - } - lit, err := strconv.Unquote(e.Value) - if err != nil { - return err - } - // The end of the literal (usually "symbol", - // "pkg.symbol", or "pkg.symbol:+1") must - // match a word boundary. However, the start - // of the literal need not: an input line such - // as "domain.name/dir/pkg.symbol:+1" should - // match literal "pkg.symbol", but the slash - // is not a word boundary (witness: - // https://go.dev/play/p/w-8ev_VUBSq). - // - // It may match multiple words if it contains - // non-word runes like whitespace. - // - // The constructed regular expression is always valid. 
- literalRegexps[e] = regexp.MustCompile(regexp.QuoteMeta(lit) + `\b`) - - default: - return fmt.Errorf("syntax error (%T)", e) - } - return nil - } - if err := validate(expr); err != nil { - return nil, err - } - - return func(stack string) bool { - var eval func(ast.Expr) bool - eval = func(e ast.Expr) bool { - switch e := e.(type) { - case *ast.UnaryExpr: - return !eval(e.X) - - case *ast.BinaryExpr: - if e.Op == token.LAND { - return eval(e.X) && eval(e.Y) - } else { - return eval(e.X) || eval(e.Y) - } - - case *ast.ParenExpr: - return eval(e.X) - - case *ast.BasicLit: - return literalRegexps[e].MatchString(stack) - } - panic("unreachable") - } - return eval(expr) - }, nil -} - -// claimStacks maps each stack ID to its issue (if any). -// -// It returns a map of stack text to the issue that claimed it. -// -// An issue can claim a stack two ways: -// -// 1. if the issue body contains the ID of the stack. Matching -// is a little loose but base64 will rarely produce words -// that appear in the body by chance. -// -// 2. if the issue body contains a ```#!stacks``` predicate -// that matches the stack. -// -// We log an error if two different issues attempt to claim -// the same stack. -func claimStacks(issues []*Issue, stacks map[string]map[Info]int64) map[string]*Issue { - // This is O(new stacks x existing issues). 
- claimedBy := make(map[string]*Issue) - for stack := range stacks { - id := stackID(stack) - for _, issue := range issues { - byPredicate := false - if strings.Contains(issue.Body, id) { - // nop - } else if issue.predicate != nil && issue.predicate(stack) { - byPredicate = true - } else { - continue - } - - if prev := claimedBy[id]; prev != nil && prev != issue { - log.Printf("stack %s is claimed by issues #%d and #%d:%s", - id, prev.Number, issue.Number, strings.ReplaceAll("\n"+stack, "\n", "\n- ")) - continue - } - if false { - log.Printf("stack %s claimed by issue #%d", - id, issue.Number) - } - claimedBy[id] = issue - if byPredicate { - // The stack ID matched the predicate but was not - // found in the issue body, so this is a new stack. - issue.newStacks = append(issue.newStacks, stack) - } - } - } - - return claimedBy -} - -// updateIssues updates existing issues that claimed new stacks by predicate. -func updateIssues(cli *githubClient, repo string, issues []*Issue, stacks map[string]map[Info]int64, stackToURL map[string]string) { - for _, issue := range issues { - if len(issue.newStacks) == 0 { - continue - } - - // Add a comment to the existing issue listing all its new stacks. - // (Save the ID of each stack for the second step.) - comment := new(bytes.Buffer) - var newStackIDs []string - for _, stack := range issue.newStacks { - id := stackID(stack) - newStackIDs = append(newStackIDs, id) - writeStackComment(comment, stack, id, stackToURL[stack], stacks[stack]) - } - - if err := cli.addIssueComment(repo, issue.Number, comment.String()); err != nil { - log.Println(err) - continue - } - - // Append to the "Dups: ID ..." list on last line of issue body. 
- body := strings.TrimSpace(issue.Body) - lastLineStart := strings.LastIndexByte(body, '\n') + 1 - lastLine := body[lastLineStart:] - if !strings.HasPrefix(lastLine, "Dups:") { - body += "\nDups:" - } - body += " " + strings.Join(newStackIDs, " ") - - update := updateIssue{number: issue.Number, Body: body} - if shouldReopen(issue, stacks) { - update.State = "open" - update.StateReason = "reopened" - } - if err := cli.updateIssue(repo, update); err != nil { - log.Printf("added comment to issue #%d but failed to update: %v", - issue.Number, err) - continue - } - - log.Printf("added stacks %s to issue #%d", newStackIDs, issue.Number) - } -} - -// An issue should be re-opened if it was closed as fixed, and at least one of the -// new stacks happened since the version containing the fix. -func shouldReopen(issue *Issue, stacks map[string]map[Info]int64) bool { - if !issue.isFixed() { - return false - } - issueProgram, issueVersion, ok := parseMilestone(issue.Milestone) - if !ok { - return false - } - - matchProgram := func(infoProg string) bool { - switch issueProgram { - case "gopls": - return path.Base(infoProg) == issueProgram - case "go": - // At present, we only care about compiler stacks. - // Issues should have milestones like "Go1.24". - return infoProg == "cmd/compile" - default: - return false - } - } - - for _, stack := range issue.newStacks { - for info := range stacks[stack] { - if matchProgram(info.Program) && semver.Compare(semVer(info.ProgramVersion), issueVersion) >= 0 { - log.Printf("reopening issue #%d: purportedly fixed in %s@%s, but found a new stack from version %s", - issue.Number, issueProgram, issueVersion, info.ProgramVersion) - return true - } - } - } - return false -} - -// An issue is fixed if it was closed because it was completed. -func (i *Issue) isFixed() bool { - return i.State == "closed" && i.StateReason == "completed" -} - -// parseMilestone parses a the title of a GitHub milestone. 
-// If it is in the format PROGRAM/VERSION (for example, "gopls/v0.17.0"), -// then it returns PROGRAM and VERSION. -// If it is in the format Go1.X, then it returns "go" as the program and -// "v1.X" or "v1.X.0" as the version. -// Otherwise, the last return value is false. -func parseMilestone(m *Milestone) (program, version string, ok bool) { - if m == nil { - return "", "", false - } - if strings.HasPrefix(m.Title, "Go") { - v := semVer(m.Title) - if !semver.IsValid(v) { - return "", "", false - } - return "go", v, true - } - program, version, ok = morestrings.CutLast(m.Title, "/") - if !ok || program == "" || version == "" || version[0] != 'v' { - return "", "", false - } - return program, version, true -} - -// semVer returns a semantic version for its argument, which may already be -// a semantic version, or may be a Go version. -// -// v1.2.3 => v1.2.3 -// go1.24 => v1.24 -// Go1.23.5 => v1.23.5 -// goHome => vHome -// -// It returns "", false if the go version is in the wrong format. -func semVer(v string) string { - if strings.HasPrefix(v, "go") || strings.HasPrefix(v, "Go") { - return "v" + v[2:] - } - return v -} - -// stackID returns a 32-bit identifier for a stack -// suitable for use in GitHub issue titles. -func stackID(stack string) string { - // Encode it using base64 (6 bytes) for brevity, - // as a single issue's body might contain multiple IDs - // if separate issues with same cause were manually de-duped, - // e.g. "AAAAAA, BBBBBB" - // - // https://hbfs.wordpress.com/2012/03/30/finding-collisions: - // the chance of a collision is 1 - exp(-n(n-1)/2d) where n - // is the number of items and d is the number of distinct values. - // So, even with n=10^4 telemetry-reported stacks each identified - // by a uint32 (d=2^32), we have a 1% chance of a collision, - // which is plenty good enough. 
- h := fnv.New32() - io.WriteString(h, stack) - return base64.URLEncoding.EncodeToString(h.Sum(nil))[:6] -} - -// newIssue creates a browser tab with a populated GitHub "New issue" -// form for the specified stack. (The triage person is expected to -// manually de-dup the issue before deciding whether to submit the form.) -// -// It returns the title. -func newIssue(pcfg ProgramConfig, stack, id, jsonURL string, counts map[Info]int64) string { - // Use a heuristic to find a suitable symbol to blame in the title: the - // first public function or method of a public type, in - // MatchSymbolPrefix, to appear in the stack trace. We can always - // refine it later. - // - // TODO(adonovan): include in the issue a source snippet ±5 - // lines around the PC in this symbol. - var symbol string -outer: - for line := range strings.SplitSeq(stack, "\n") { - for _, s := range pcfg.IgnoreSymbolContains { - if strings.Contains(line, s) { - continue outer // not interesting - } - } - // Look for: - // pcfg.MatchSymbolPrefix/.../pkg.Func - // pcfg.MatchSymbolPrefix/.../pkg.Type.method - // pcfg.MatchSymbolPrefix/.../pkg.(*Type).method - if _, rest, ok := strings.Cut(line, pcfg.MatchSymbolPrefix); ok { - if i := strings.IndexByte(rest, '.'); i >= 0 { - rest = rest[i+1:] - rest = strings.TrimPrefix(rest, "(*") - if rest != "" && 'A' <= rest[0] && rest[0] <= 'Z' { - rest, _, _ = strings.Cut(rest, ":") - symbol = " " + rest - break - } - } - } - } - - // Populate the form (title, body, label) - title := fmt.Sprintf("%s: bug in %s", pcfg.NewIssuePrefix, symbol) - - body := new(bytes.Buffer) - - // Add a placeholder ```#!stacks``` block since this is a new issue. - body.WriteString("```" + ` -#!stacks -"" -` + "```\n") - fmt.Fprintf(body, "Issue created by [stacks](https://pkg.go.dev/golang.org/x/tools/gopls/internal/telemetry/cmd/stacks).\n\n") - - writeStackComment(body, stack, id, jsonURL, counts) - - labels := strings.Join(pcfg.NewIssueLabels, ",") - - // Report it. 
The user will interactively finish the task, - // since they will typically de-dup it without even creating a new issue - // by expanding the #!stacks predicate of an existing issue. - if !browser.Open("https://github.com/" + pcfg.Repository + "/issues/new?labels=" + labels + "&title=" + url.QueryEscape(title) + "&body=" + url.QueryEscape(body.String())) { - log.Print("Please file a new issue at golang.org/issue/new using this template:\n\n") - log.Printf("Title: %s\n", title) - log.Printf("Labels: %s\n", labels) - log.Printf("Body: %s\n", body) - } - - return title -} - -// writeStackComment writes a stack in Markdown form, for a new GitHub -// issue or new comment on an existing one. -func writeStackComment(body *bytes.Buffer, stack, id string, jsonURL string, counts map[Info]int64) { - if len(counts) == 0 { - panic("no counts") - } - var info Info // pick an arbitrary key - for info = range counts { - break - } - - fmt.Fprintf(body, "This stack `%s` was [reported by telemetry](%s):\n\n", - id, jsonURL) - - // Read the mapping from symbols to file/line. - pclntab, err := readPCLineTable(info, defaultStacksDir) - if err != nil { - log.Fatal(err) - } - - // Parse the stack and get the symbol names out. - for frame := range strings.SplitSeq(stack, "\n") { - if url := frameURL(pclntab, info, frame); url != "" { - fmt.Fprintf(body, "- [`%s`](%s)\n", frame, url) - } else { - fmt.Fprintf(body, "- `%s`\n", frame) - } - } - - // Add counts, gopls version, and platform info. - // This isn't very precise but should provide clues. - fmt.Fprintf(body, "```\n") - for info, count := range counts { - fmt.Fprintf(body, "%s (%d)\n", info, count) - } - fmt.Fprintf(body, "```\n\n") -} - -// frameURL returns the CodeSearch URL for the stack frame, if known. -func frameURL(pclntab map[string]FileLine, info Info, frame string) string { - // e.g. 
"golang.org/x/tools/gopls/foo.(*Type).Method.inlined.func3:+5" - symbol, offset, ok := strings.Cut(frame, ":") - if !ok { - // Not a symbol (perhaps stack counter title: "gopls/bug"?) - return "" - } - - fileline, ok := pclntab[symbol] - if !ok { - // objdump reports ELF symbol names, which in - // rare cases may be the Go symbols of - // runtime.CallersFrames mangled by (e.g.) the - // addition of .abi0 suffix; see - // https://github.com/golang/go/issues/69390#issuecomment-2343795920 - // So this should not be a hard error. - if symbol != "runtime.goexit" { - log.Printf("no pclntab info for symbol: %s", symbol) - } - return "" - } - - if offset == "" { - log.Fatalf("missing line offset: %s", frame) - } - if unicode.IsDigit(rune(offset[0])) { - // Fix gopls/v0.14.2 legacy syntax ":%d" -> ":+%d". - offset = "+" + offset - } - offsetNum, err := strconv.Atoi(offset[1:]) - if err != nil { - log.Fatalf("invalid line offset: %s", frame) - } - linenum := fileline.line - switch offset[0] { - case '-': - linenum -= offsetNum - case '+': - linenum += offsetNum - case '=': - linenum = offsetNum - } - - // Construct CodeSearch URL. - - // std module? - firstSegment, _, _ := strings.Cut(fileline.file, "/") - if !strings.Contains(firstSegment, ".") { - // (First segment is a dir beneath GOROOT/src, not a module domain name.) - return fmt.Sprintf("https://cs.opensource.google/go/go/+/%s:src/%s;l=%d", - info.GoVersion, fileline.file, linenum) - } - - // x/tools repo (tools or gopls module)? - if rest, ok := strings.CutPrefix(fileline.file, "golang.org/x/tools"); ok { - if rest[0] == '/' { - // "golang.org/x/tools/gopls" -> "gopls" - rest = rest[1:] - } else if rest[0] == '@' { - // "golang.org/x/tools@version/dir/file.go" -> "dir/file.go" - rest = rest[strings.Index(rest, "/")+1:] - } - - return fmt.Sprintf("https://cs.opensource.google/go/x/tools/+/%s:%s;l=%d", - "gopls/"+info.ProgramVersion, rest, linenum) - } - - // other x/ module dependency? - // e.g. 
golang.org/x/sync@v0.8.0/errgroup/errgroup.go - if rest, ok := strings.CutPrefix(fileline.file, "golang.org/x/"); ok { - if modVer, filename, ok := strings.Cut(rest, "/"); ok { - if mod, version, ok := strings.Cut(modVer, "@"); ok { - return fmt.Sprintf("https://cs.opensource.google/go/x/%s/+/%s:%s;l=%d", - mod, version, filename, linenum) - } - } - } - - // Delve - const delveRepo = "github.com/go-delve/delve/" - if strings.HasPrefix(fileline.file, delveRepo) { - filename := fileline.file[len(delveRepo):] - return fmt.Sprintf("https://%sblob/%s/%s#L%d", delveRepo, info.ProgramVersion, filename, linenum) - - } - - log.Printf("no CodeSearch URL for %q (%s:%d)", - symbol, fileline.file, linenum) - return "" -} - -// -- GitHub client -- - -// A githubClient interacts with GitHub. -// During testing, updates to GitHub are saved in changes instead of being applied. -// Reads from GitHub occur normally. -type githubClient struct { - authToken string // mandatory GitHub authentication token (for R/W issues access) - divertChanges bool // divert attempted GitHub changes to the changes field instead of executing them - changes []any // slice of (addIssueComment | updateIssueBody) -} - -func (cli *githubClient) takeChanges() []any { - r := cli.changes - cli.changes = nil - return r -} - -// addIssueComment is a change for creating a comment on an issue. -type addIssueComment struct { - number int - comment string -} - -// updateIssue is a change for modifying an existing issue. -// It includes the issue number and the fields that can be updated on a GitHub issue. -// A JSON-marshaled updateIssue can be used as the body of the update request sent to GitHub. -// See https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#update-an-issue. 
-type updateIssue struct { - number int // issue number; must be unexported - Body string `json:"body,omitempty"` - State string `json:"state,omitempty"` // "open" or "closed" - StateReason string `json:"state_reason,omitempty"` // "completed", "not_planned", "reopened" -} - -// -- GitHub search -- - -// searchIssues queries the GitHub issue tracker. -func (cli *githubClient) searchIssues(repo, label string) ([]*Issue, error) { - label = url.QueryEscape(label) - - // Slurp all issues with the telemetry label. - // - // The pagination link headers have an annoying format, but ultimately - // are just ?page=1, ?page=2, etc with no extra state. So just keep - // trying new pages until we get no more results. - // - // NOTE: With this scheme, GitHub clearly has no protection against - // race conditions, so presumably we could get duplicate issues or miss - // issues across pages. - - getPage := func(page int) ([]*Issue, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/issues?state=all&labels=%s&per_page=100&page=%d", repo, label, page) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, err - } - req.Header.Add("Authorization", "Bearer "+cli.authToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("search query %s failed: %s (body: %s)", url, resp.Status, body) - } - var r []*Issue - if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return nil, err - } - - return r, nil - } - - var results []*Issue - for page := 1; ; page++ { - r, err := getPage(page) - if err != nil { - return nil, err - } - if len(r) == 0 { - // No more results. - break - } - - results = append(results, r...) - } - - return results, nil -} - -// updateIssue updates the numbered issue. 
-func (cli *githubClient) updateIssue(repo string, update updateIssue) error { - if cli.divertChanges { - cli.changes = append(cli.changes, update) - return nil - } - - data, err := json.Marshal(update) - if err != nil { - return err - } - - url := fmt.Sprintf("https://api.github.com/repos/%s/issues/%d", repo, update.number) - if err := cli.requestChange("PATCH", url, data, http.StatusOK); err != nil { - return fmt.Errorf("updating issue: %v", err) - } - return nil -} - -// addIssueComment adds a markdown comment to the numbered issue. -func (cli *githubClient) addIssueComment(repo string, number int, comment string) error { - if cli.divertChanges { - cli.changes = append(cli.changes, addIssueComment{number, comment}) - return nil - } - - // https://docs.github.com/en/rest/issues/comments#create-an-issue-comment - var payload struct { - Body string `json:"body"` - } - payload.Body = comment - data, err := json.Marshal(payload) - if err != nil { - return err - } - - url := fmt.Sprintf("https://api.github.com/repos/%s/issues/%d/comments", repo, number) - if err := cli.requestChange("POST", url, data, http.StatusCreated); err != nil { - return fmt.Errorf("creating issue comment: %v", err) - } - return nil -} - -// requestChange sends a request to url using method, which may change the state at the server. -// The data is sent as the request body, and wantStatus is the expected response status code. 
-func (cli *githubClient) requestChange(method, url string, data []byte, wantStatus int) error { - if *dryRun { - log.Printf("DRY RUN: %s %s", method, url) - return nil - } - req, err := http.NewRequest(method, url, bytes.NewReader(data)) - if err != nil { - return err - } - req.Header.Add("Authorization", "Bearer "+cli.authToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != wantStatus { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("request failed: %s (body: %s)", resp.Status, body) - } - return nil -} - -// See https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#list-repository-issues. - -type Issue struct { - Number int - HTMLURL string `json:"html_url"` - Title string - State string - StateReason string `json:"state_reason"` - User *User - CreatedAt time.Time `json:"created_at"` - Body string // in Markdown format - Milestone *Milestone - - // Set by readIssues. - predicate func(string) bool // matching predicate over stack text - - // Set by claimIssues. - newStacks []string // new stacks to add to existing issue (comments and IDs) -} - -func (issue *Issue) String() string { return fmt.Sprintf("#%d", issue.Number) } - -type User struct { - Login string - HTMLURL string `json:"html_url"` -} - -type Milestone struct { - Title string -} - -// -- pclntab -- - -type FileLine struct { - file string // "module@version/dir/file.go" or path relative to $GOROOT/src - line int -} - -const defaultStacksDir = "/tmp/stacks-cache" - -// readPCLineTable builds the gopls executable specified by info, -// reads its PC-to-line-number table, and returns the file/line of -// each TEXT symbol. -// -// stacksDir is a semi-durable temp directory (i.e. lasts for at least a few -// hours) to hold recent sources and executables. -func readPCLineTable(info Info, stacksDir string) (map[string]FileLine, error) { - // The stacks dir will be a semi-durable temp directory - // (i.e. 
lasts for at least hours) holding source trees - // and executables we have built recently. - // - // Each subdir will hold a specific revision. - if err := os.MkdirAll(stacksDir, 0777); err != nil { - return nil, fmt.Errorf("can't create stacks dir: %v", err) - } - - // When building a subrepo tool, we must clone the source of the - // subrepo, and run go build from that checkout. - // - // When building a main repo tool, no need to clone or change - // directories. GOTOOLCHAIN is sufficient to fetch and build the - // appropriate version. - var buildDir string - switch info.Program { - case "golang.org/x/tools/gopls": - // Fetch the source for the tools repo, - // shallow-cloning just the desired revision. - // (Skip if it's already cloned.) - revDir := filepath.Join(stacksDir, info.ProgramVersion) - if !fileExists(filepath.Join(revDir, "go.mod")) { - // We check for presence of the go.mod file, - // not just the directory itself, as the /tmp reaper - // often removes stale files before removing their directories. - // Remove those stale directories now. - _ = os.RemoveAll(revDir) // ignore errors - - // TODO(prattmic): Consider using ProgramConfig - // configuration if we add more configurations. - log.Printf("cloning tools@gopls/%s", info.ProgramVersion) - if err := shallowClone(revDir, "https://go.googlesource.com/tools", "gopls/"+info.ProgramVersion); err != nil { - _ = os.RemoveAll(revDir) // ignore errors - return nil, fmt.Errorf("clone: %v", err) - } - } - - // gopls is in its own module, we must build from there. - buildDir = filepath.Join(revDir, "gopls") - case "cmd/compile": - // Nothing to do, GOTOOLCHAIN is sufficient. - - // Switch build directories so if we happen to be in Go module - // directory its go.mod doesn't restrict the toolchain versions - // we're allowed to use. 
- buildDir = "/" - case "github.com/go-delve/delve/cmd/dlv": - revDir := filepath.Join(stacksDir, "delve@"+info.ProgramVersion) - if !fileExists(filepath.Join(revDir, "go.mod")) { - _ = os.RemoveAll(revDir) - log.Printf("cloning github.com/go-delve/delve@%s", info.ProgramVersion) - if err := shallowClone(revDir, "https://github.com/go-delve/delve", info.ProgramVersion); err != nil { - _ = os.RemoveAll(revDir) - return nil, fmt.Errorf("clone: %v", err) - } - } - buildDir = revDir - default: - return nil, fmt.Errorf("don't know how to build unknown program %s", info.Program) - } - - // No slashes in file name. - escapedProg := strings.Replace(info.Program, "/", "_", -1) - - // Build the executable with the correct GOTOOLCHAIN, GOOS, GOARCH. - // Use -trimpath for normalized file names. - // (Skip if it's already built.) - exe := fmt.Sprintf("exe-%s-%s.%s-%s", escapedProg, info.GoVersion, info.GOOS, info.GOARCH) - exe = filepath.Join(stacksDir, exe) - - if !fileExists(exe) { - log.Printf("building %s@%s with %s for %s/%s", - info.Program, info.ProgramVersion, info.GoVersion, info.GOOS, info.GOARCH) - - cmd := exec.Command("go", "build", "-trimpath", "-o", exe, info.Program) - cmd.Stderr = os.Stderr - cmd.Dir = buildDir - cmd.Env = append(os.Environ(), - "GOTOOLCHAIN="+info.GoVersion, - "GOEXPERIMENT=", // Don't forward GOEXPERIMENT from current environment since the GOTOOLCHAIN selected might not support the same experiments. - "GOOS="+info.GOOS, - "GOARCH="+info.GOARCH, - ) - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("building: %v (rm -fr %s?)", err, stacksDir) - } - } - - // Read pclntab of executable. - cmd := exec.Command("go", "tool", "objdump", exe) - cmd.Stdout = new(strings.Builder) - cmd.Stderr = os.Stderr - cmd.Env = append(os.Environ(), - "GOTOOLCHAIN="+info.GoVersion, - "GOEXPERIMENT=", // Don't forward GOEXPERIMENT from current environment since the GOTOOLCHAIN selected might not support the same experiments. 
- "GOOS="+info.GOOS, - "GOARCH="+info.GOARCH, - ) - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("reading pclntab %v", err) - } - pclntab := make(map[string]FileLine) - lines := strings.Split(fmt.Sprint(cmd.Stdout), "\n") - for i, line := range lines { - // Each function is of this form: - // - // TEXT symbol(SB) filename - // basename.go:line instruction - // ... - if !strings.HasPrefix(line, "TEXT ") { - continue - } - fields := strings.Fields(line) - if len(fields) != 3 { - continue // symbol without file (e.g. go:buildid) - } - - symbol := strings.TrimSuffix(fields[1], "(SB)") - - filename := fields[2] - - _, line, ok := strings.Cut(strings.Fields(lines[i+1])[0], ":") - if !ok { - return nil, fmt.Errorf("can't parse 'basename.go:line' from first instruction of %s:\n%s", - symbol, line) - } - linenum, err := strconv.Atoi(line) - if err != nil { - return nil, fmt.Errorf("can't parse line number of %s: %s", symbol, line) - } - pclntab[symbol] = FileLine{filename, linenum} - } - - return pclntab, nil -} - -// shallowClone performs a shallow clone of repo into dir at the given -// 'commitish' ref (any commit reference understood by git). -// -// The directory dir must not already exist. -func shallowClone(dir, repo, commitish string) error { - if err := os.Mkdir(dir, 0750); err != nil { - return fmt.Errorf("creating dir for %s: %v", repo, err) - } - - // Set a timeout for git fetch. If this proves flaky, it can be removed. - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - - // Use a shallow fetch to download just the relevant commit. 
- shInit := fmt.Sprintf("git init && git fetch --depth=1 %q %q && git checkout FETCH_HEAD", repo, commitish) - initCmd := exec.CommandContext(ctx, "/bin/sh", "-c", shInit) - initCmd.Dir = dir - if output, err := initCmd.CombinedOutput(); err != nil { - return fmt.Errorf("checking out %s: %v\n%s", repo, err, output) - } - return nil -} - -func fileExists(filename string) bool { - _, err := os.Stat(filename) - return err == nil -} - -// findPredicateBlock returns the content (sans "#!stacks") of the -// code block at the start of the issue body. -// Logic plundered from x/build/cmd/watchflakes/github.go. -func findPredicateBlock(body string) string { - // Extract ```-fenced or indented code block at start of issue description (body). - body = strings.ReplaceAll(body, "\r\n", "\n") - lines := strings.SplitAfter(body, "\n") - for len(lines) > 0 && strings.TrimSpace(lines[0]) == "" { - lines = lines[1:] - } - text := "" - // A code quotation is bracketed by sequence of 3+ backticks. - // (More than 3 are permitted so that one can quote 3 backticks.) - if len(lines) > 0 && strings.HasPrefix(lines[0], "```") { - marker := lines[0] - n := 0 - for n < len(marker) && marker[n] == '`' { - n++ - } - marker = marker[:n] - i := 1 - for i := 1; i < len(lines); i++ { - if strings.HasPrefix(lines[i], marker) && strings.TrimSpace(strings.TrimLeft(lines[i], "`")) == "" { - text = strings.Join(lines[1:i], "") - break - } - } - if i < len(lines) { - } - } else if strings.HasPrefix(lines[0], "\t") || strings.HasPrefix(lines[0], " ") { - i := 1 - for i < len(lines) && (strings.HasPrefix(lines[i], "\t") || strings.HasPrefix(lines[i], " ")) { - i++ - } - text = strings.Join(lines[:i], "") - } - - // Must start with #!stacks so we're sure it is for us. - hdr, rest, _ := strings.Cut(text, "\n") - hdr = strings.TrimSpace(hdr) - if hdr != "#!stacks" { - return "" - } - return rest -} - -// isTerminal reports whether file is a terminal, -// avoiding a dependency on golang.org/x/term. 
-func isTerminal(file *os.File) bool { - // Hardwire the constants to avoid the need for build tags. - // The values here are good for our dev machines. - switch runtime.GOOS { - case "darwin": - const TIOCGETA = 0x40487413 // from unix.TIOCGETA - _, err := unix.IoctlGetTermios(int(file.Fd()), TIOCGETA) - return err == nil - case "linux": - const TCGETS = 0x5401 // from unix.TCGETS - _, err := unix.IoctlGetTermios(int(file.Fd()), TCGETS) - return err == nil - } - panic("unreachable") -} diff --git a/gopls/internal/telemetry/cmd/stacks/stacks_test.go b/gopls/internal/telemetry/cmd/stacks/stacks_test.go deleted file mode 100644 index d7bf12f830f..00000000000 --- a/gopls/internal/telemetry/cmd/stacks/stacks_test.go +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build linux || darwin - -package main - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestReadPCLineTable(t *testing.T) { - if testing.Short() { - // TODO(prattmic): It would be nice to have a unit test that - // didn't require downloading. 
- t.Skip("downloads source from the internet, skipping in -short") - } - - type testCase struct { - name string - info Info - wantSymbol string - wantFileLine FileLine - } - - tests := []testCase{ - { - name: "gopls", - info: Info{ - Program: "golang.org/x/tools/gopls", - ProgramVersion: "v0.16.1", - GoVersion: "go1.23.4", - GOOS: "linux", - GOARCH: "amd64", - }, - wantSymbol: "golang.org/x/tools/gopls/internal/cmd.(*Application).Run", - wantFileLine: FileLine{ - file: "golang.org/x/tools/gopls/internal/cmd/cmd.go", - line: 230, - }, - }, - { - name: "compile", - info: Info{ - Program: "cmd/compile", - ProgramVersion: "go1.23.4", - GoVersion: "go1.23.4", - GOOS: "linux", - GOARCH: "amd64", - }, - wantSymbol: "runtime.main", - wantFileLine: FileLine{ - file: "runtime/proc.go", - line: 147, - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - stacksDir := t.TempDir() - pcln, err := readPCLineTable(tc.info, stacksDir) - if err != nil { - t.Fatalf("readPCLineTable got err %v want nil", err) - } - - got, ok := pcln[tc.wantSymbol] - if !ok { - t.Fatalf("PCLineTable want entry %s got !ok from pcln %+v", tc.wantSymbol, pcln) - } - - if got != tc.wantFileLine { - t.Fatalf("symbol %s got FileLine %+v want %+v", tc.wantSymbol, got, tc.wantFileLine) - } - }) - } -} - -func TestParsePredicate(t *testing.T) { - for _, tc := range []struct { - expr string - arg string - want bool - }{ - {`"x"`, `"x"`, true}, - {`"x"`, `"axe"`, false}, // literals must match word ends - {`"xe"`, `"axe"`, true}, - {`"x"`, "val:x+5", true}, - {`"fu+12"`, "x:fu+12,", true}, - {`"fu+12"`, "snafu+12,", true}, // literals needn't match word start - {`"fu+12"`, "x:fu+123,", false}, - {`"foo:+12"`, "dir/foo:+12,", true}, // literals needn't match word start - {`"a.*b"`, "a.*b", true}, // regexp metachars are escaped - {`"a.*b"`, "axxb", false}, // ditto - {`"x"`, `"y"`, false}, - {`!"x"`, "x", false}, - {`!"x"`, "y", true}, - {`"x" && "y"`, "xy", false}, - {`"x" && "y"`, "x 
y", true}, - {`"x" && "y"`, "x", false}, - {`"x" && "y"`, "y", false}, - {`"xz" && "zy"`, "xzy", false}, - {`"xz" && "zy"`, "zy,xz", true}, - {`"x" || "y"`, "x\ny", true}, - {`"x" || "y"`, "x", true}, - {`"x" || "y"`, "y", true}, - {`"x" || "y"`, "z", false}, - } { - eval, err := parsePredicate(tc.expr) - if err != nil { - t.Fatal(err) - } - got := eval(tc.arg) - if got != tc.want { - t.Errorf("%s applied to %q: got %t, want %t", tc.expr, tc.arg, got, tc.want) - } - } -} - -func TestParsePredicateError(t *testing.T) { - // Validate that bad predicates return errors. - for _, expr := range []string{ - ``, - `1`, - `foo`, // an identifier, not a literal - `"x" + "y"`, - `"x" &&`, - `~"x"`, - `f(1)`, - } { - if _, err := parsePredicate(expr); err == nil { - t.Errorf("%s: got nil, want error", expr) - } - } -} - -// which takes the bulk of the time. -func TestUpdateIssues(t *testing.T) { - if testing.Short() { - t.Skip("downloads source from the internet, skipping in -short") - } - - c := &githubClient{divertChanges: true} - const stack1 = "stack1" - id1 := stackID(stack1) - stacksToURL := map[string]string{stack1: "URL1"} - - // checkIssueComment asserts that the change adds an issue of the specified - // number, with a body that contains various strings. 
- checkIssueComment := func(t *testing.T, change any, number int, version string) { - t.Helper() - cic, ok := change.(addIssueComment) - if !ok { - t.Fatalf("got %T, want addIssueComment", change) - } - if cic.number != number { - t.Errorf("issue number: got %d, want %d", cic.number, number) - } - for _, want := range []string{"URL1", stack1, id1, "golang.org/x/tools/gopls@" + version} { - if !strings.Contains(cic.comment, want) { - t.Errorf("missing %q in comment:\n%s", want, cic.comment) - } - } - } - - t.Run("open issue", func(t *testing.T) { - issues := []*Issue{{ - Number: 1, - State: "open", - newStacks: []string{stack1}, - }} - - info := Info{ - Program: "golang.org/x/tools/gopls", - ProgramVersion: "v0.16.1", - } - stacks := map[string]map[Info]int64{stack1: map[Info]int64{info: 3}} - updateIssues(c, "golang/go", issues, stacks, stacksToURL) - changes := c.takeChanges() - - if g, w := len(changes), 2; g != w { - t.Fatalf("got %d changes, want %d", g, w) - } - - // The first change creates an issue comment. - checkIssueComment(t, changes[0], 1, "v0.16.1") - - // The second change updates the issue body, and only the body. - ui, ok := changes[1].(updateIssue) - if !ok { - t.Fatalf("got %T, want updateIssue", changes[1]) - } - if ui.number != 1 { - t.Errorf("issue number: got %d, want 1", ui.number) - } - if ui.Body == "" || ui.State != "" || ui.StateReason != "" { - t.Errorf("updating other than just the body:\n%+v", ui) - } - want := "Dups: " + id1 - if !strings.Contains(ui.Body, want) { - t.Errorf("missing %q in body %q", want, ui.Body) - } - }) - t.Run("should be reopened", func(t *testing.T) { - issues := []*Issue{{ - // Issue purportedly fixed in v0.16.0 - Number: 2, - State: "closed", - StateReason: "completed", - Milestone: &Milestone{Title: "gopls/v0.16.0"}, - newStacks: []string{stack1}, - }} - // New stack in a later version. 
- info := Info{ - Program: "golang.org/x/tools/gopls", - ProgramVersion: "v0.17.0", - } - stacks := map[string]map[Info]int64{stack1: map[Info]int64{info: 3}} - updateIssues(c, "golang/go", issues, stacks, stacksToURL) - - changes := c.takeChanges() - if g, w := len(changes), 2; g != w { - t.Fatalf("got %d changes, want %d", g, w) - } - // The first change creates an issue comment. - checkIssueComment(t, changes[0], 2, "v0.17.0") - - // The second change updates the issue body, state, and state reason. - ui, ok := changes[1].(updateIssue) - if !ok { - t.Fatalf("got %T, want updateIssue", changes[1]) - } - if ui.number != 2 { - t.Errorf("issue number: got %d, want 2", ui.number) - } - if ui.Body == "" || ui.State != "open" || ui.StateReason != "reopened" { - t.Errorf(`update fields should be non-empty body, state "open", state reason "reopened":\n%+v`, ui) - } - want := "Dups: " + id1 - if !strings.Contains(ui.Body, want) { - t.Errorf("missing %q in body %q", want, ui.Body) - } - - }) - -} - -func TestMarshalUpdateIssueFields(t *testing.T) { - // Verify that only the non-empty fields of updateIssueFields are marshalled. 
- for _, tc := range []struct { - fields updateIssue - want string - }{ - {updateIssue{Body: "b"}, `{"body":"b"}`}, - {updateIssue{State: "open"}, `{"state":"open"}`}, - {updateIssue{State: "open", StateReason: "reopened"}, `{"state":"open","state_reason":"reopened"}`}, - } { - bytes, err := json.Marshal(tc.fields) - if err != nil { - t.Fatal(err) - } - got := string(bytes) - if got != tc.want { - t.Errorf("%+v: got %s, want %s", tc.fields, got, tc.want) - } - } -} - -func TestShouldReopen(t *testing.T) { - const stack = "stack" - const gopls = "golang.org/x/tools/gopls" - goplsMilestone := &Milestone{Title: "gopls/v0.2.0"} - goMilestone := &Milestone{Title: "Go1.23"} - - for _, tc := range []struct { - name string - issue Issue - info Info - want bool - }{ - { - "issue open", - Issue{State: "open", Milestone: goplsMilestone}, - Info{Program: gopls, ProgramVersion: "v0.2.0"}, - false, - }, - { - "issue closed but not fixed", - Issue{State: "closed", StateReason: "not_planned", Milestone: goplsMilestone}, - Info{Program: gopls, ProgramVersion: "v0.2.0"}, - false, - }, - { - "different program", - Issue{State: "closed", StateReason: "completed", Milestone: goplsMilestone}, - Info{Program: "other", ProgramVersion: "v0.2.0"}, - false, - }, - { - "later version", - Issue{State: "closed", StateReason: "completed", Milestone: goplsMilestone}, - Info{Program: gopls, ProgramVersion: "v0.3.0"}, - true, - }, - { - "earlier version", - Issue{State: "closed", StateReason: "completed", Milestone: goplsMilestone}, - Info{Program: gopls, ProgramVersion: "v0.1.0"}, - false, - }, - { - "same version", - Issue{State: "closed", StateReason: "completed", Milestone: goplsMilestone}, - Info{Program: gopls, ProgramVersion: "v0.2.0"}, - true, - }, - { - "compiler later version", - Issue{State: "closed", StateReason: "completed", Milestone: goMilestone}, - Info{Program: "cmd/compile", ProgramVersion: "go1.24"}, - true, - }, - { - "compiler earlier version", - Issue{State: "closed", 
StateReason: "completed", Milestone: goMilestone}, - Info{Program: "cmd/compile", ProgramVersion: "go1.22"}, - false, - }, - { - "compiler same version", - Issue{State: "closed", StateReason: "completed", Milestone: goMilestone}, - Info{Program: "cmd/compile", ProgramVersion: "go1.23"}, - true, - }, - } { - t.Run(tc.name, func(t *testing.T) { - tc.issue.Number = 1 - tc.issue.newStacks = []string{stack} - got := shouldReopen(&tc.issue, map[string]map[Info]int64{stack: map[Info]int64{tc.info: 1}}) - if got != tc.want { - t.Errorf("got %t, want %t", got, tc.want) - } - }) - } -} From 6e8a193e2414a54f8df129199168f6fad76efb52 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 12 Jan 2024 18:35:44 -0500 Subject: [PATCH 181/196] gopls/internal/debug: integrate flight recorder This change adds support to gopls for Flight Recorder, the new always-on version of the Go 1.25's runtime's event tracing. Simply visit the gopls debug page and hit Flight Recorder, and you'll immediately see the event trace for the past 30 seconds. Also, a test of basic functionality. 
Updates golang/go#66843 - in process http.Handler for trace Updates golang/go#63185 - flight recorder runtime API Change-Id: I6335e986990445014a51ed923b5ee7093c723fe9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/555716 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan --- gopls/internal/debug/flight.go | 129 ++++++++++++++++++ gopls/internal/debug/flight_go124.go | 16 +++ gopls/internal/debug/serve.go | 15 +- .../test/integration/web/flight_test.go | 66 +++++++++ 4 files changed, 221 insertions(+), 5 deletions(-) create mode 100644 gopls/internal/debug/flight.go create mode 100644 gopls/internal/debug/flight_go124.go create mode 100644 gopls/internal/test/integration/web/flight_test.go diff --git a/gopls/internal/debug/flight.go b/gopls/internal/debug/flight.go new file mode 100644 index 00000000000..2eb179061d2 --- /dev/null +++ b/gopls/internal/debug/flight.go @@ -0,0 +1,129 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.25 + +package debug + +import ( + "bufio" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime/trace" + "strings" + "sync" + "time" +) + +// The FlightRecorder is a global resource, so create at most one per process. +var getRecorder = sync.OnceValues(func() (*trace.FlightRecorder, error) { + fr := trace.NewFlightRecorder(trace.FlightRecorderConfig{ + // half a minute is usually enough to know "what just happened?" + MinAge: 30 * time.Second, + }) + if err := fr.Start(); err != nil { + return nil, err + } + return fr, nil +}) + +func startFlightRecorder() (http.HandlerFunc, error) { + fr, err := getRecorder() + if err != nil { + return nil, err + } + + // Return a handler that writes the most recent flight record, + // starts a trace viewer server, and redirects to it. 
+ return func(w http.ResponseWriter, r *http.Request) { + errorf := func(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + http.Error(w, msg, http.StatusInternalServerError) + } + + // Write the most recent flight record into a temp file. + f, err := os.CreateTemp("", "flightrecord") + if err != nil { + errorf("can't create temp file for flight record: %v", err) + return + } + if _, err := fr.WriteTo(f); err != nil { + f.Close() + errorf("failed to write flight record: %s", err) + return + } + if err := f.Close(); err != nil { + errorf("failed to close flight record: %s", err) + return + } + tracefile, err := filepath.Abs(f.Name()) + if err != nil { + errorf("can't absolutize name of trace file: %v", err) + return + } + + // Run 'go tool trace' to start a new trace-viewer + // web server process. It will run until gopls terminates. + // (It would be nicer if we could just link it in; see #66843.) + cmd := exec.Command("go", "tool", "trace", tracefile) + + // Don't connect trace's std{out,err} to our os.Stderr directly, + // otherwise the child may outlive the parent in tests, + // and 'go test' will complain about unclosed pipes. + // Instead, interpose a pipe that will close when gopls exits. + // See CL 677262 for a better solution (a cmd/trace flag). + // (#66843 is of course better still.) + // Also, this notifies us of the server's readiness and URL. + urlC := make(chan string) + { + r, w, err := os.Pipe() + if err != nil { + errorf("can't create pipe: %v", err) + return + } + go func() { + // Copy from the pipe to stderr, + // keeping an eye out for the "listening on URL" string. 
+ scan := bufio.NewScanner(r) + for scan.Scan() { + line := scan.Text() + if _, url, ok := strings.Cut(line, "Trace viewer is listening on "); ok { + urlC <- url + } + fmt.Fprintln(os.Stderr, line) + } + if err := scan.Err(); err != nil { + log.Printf("reading from pipe to cmd/trace: %v", err) + } + }() + cmd.Stderr = w + cmd.Stdout = w + } + + // Suppress the usual cmd/trace behavior of opening a new + // browser tab by setting BROWSER to /usr/bin/true (a no-op). + cmd.Env = append(os.Environ(), "BROWSER=true") + if err := cmd.Start(); err != nil { + errorf("failed to start trace server: %s", err) + return + } + + select { + case addr := <-urlC: + // Success! Send a redirect to the new location. + // (This URL bypasses the help screen at /.) + http.Redirect(w, r, addr+"/trace?view=proc", 302) + + case <-r.Context().Done(): + errorf("canceled") + + case <-time.After(10 * time.Second): + errorf("trace viewer failed to start", err) + } + }, nil +} diff --git a/gopls/internal/debug/flight_go124.go b/gopls/internal/debug/flight_go124.go new file mode 100644 index 00000000000..807fa11093e --- /dev/null +++ b/gopls/internal/debug/flight_go124.go @@ -0,0 +1,16 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build !go1.25 + +package debug + +import ( + "errors" + "net/http" +) + +func startFlightRecorder() (http.HandlerFunc, error) { + return nil, errors.ErrUnsupported +} diff --git a/gopls/internal/debug/serve.go b/gopls/internal/debug/serve.go index 77a86d8c8da..c7729103b35 100644 --- a/gopls/internal/debug/serve.go +++ b/gopls/internal/debug/serve.go @@ -438,6 +438,13 @@ func (i *Instance) Serve(ctx context.Context, addr string) (string, error) { mux.HandleFunc("/debug/pprof/profile", pprof.Profile) mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + + if h, err := startFlightRecorder(); err != nil { + stdlog.Printf("failed to start flight recorder: %v", err) // e.g. go1.24 + } else { + mux.HandleFunc("/flightrecorder", h) + } + if i.prometheus != nil { mux.HandleFunc("/metrics/", i.prometheus.Serve) } @@ -468,11 +475,8 @@ func (i *Instance) Serve(ctx context.Context, addr string) (string, error) { http.Error(w, "made a bug", http.StatusOK) }) - if err := http.Serve(listener, mux); err != nil { - event.Error(ctx, "Debug server failed", err) - return - } - event.Log(ctx, "Debug server finished") + err := http.Serve(listener, mux) // always non-nil + event.Error(ctx, "Debug server failed", err) }() return i.listenedDebugAddress, nil } @@ -650,6 +654,7 @@ body { Metrics RPC Trace +Flight recorder Analysis

    {{template "title" .}}

    diff --git a/gopls/internal/test/integration/web/flight_test.go b/gopls/internal/test/integration/web/flight_test.go new file mode 100644 index 00000000000..0aba411aab0 --- /dev/null +++ b/gopls/internal/test/integration/web/flight_test.go @@ -0,0 +1,66 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package web_test + +import ( + "encoding/json" + "runtime" + "testing" + + "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/protocol/command" + . "golang.org/x/tools/gopls/internal/test/integration" + "golang.org/x/tools/internal/testenv" +) + +// TestFlightRecorder checks that the flight recorder is minimally functional. +func TestFlightRecorder(t *testing.T) { + // The usual UNIX mechanisms cause timely termination of the + // cmd/trace process, but this doesn't happen on Windows, + // leading to CI failures because of process->file locking. + // Rather than invent a complex mechanism, skip the test: + // this feature is only for gopls developers anyway. + // Better long term solutions are CL 677262 and issue #66843. + if runtime.GOOS == "windows" { + t.Skip("not reliable on windows") + } + testenv.NeedsGo1Point(t, 25) + + const files = ` +-- go.mod -- +module example.com + +-- a/a.go -- +package a + +const A = 1 +` + + Run(t, files, func(t *testing.T, env *Env) { + env.OpenFile("a/a.go") + + // Start the debug server. + var result command.DebuggingResult + env.ExecuteCommand(&protocol.ExecuteCommandParams{ + Command: command.StartDebugging.String(), + Arguments: []json.RawMessage{json.RawMessage("{}")}, // no args -> pick port + }, &result) + uri := result.URLs[0] + t.Logf("StartDebugging: URLs[0] = %s", uri) + + // Check the debug server page is sensible. 
+ doc1 := get(t, uri) + checkMatch(t, true, doc1, "Gopls server information") + checkMatch(t, true, doc1, `Flight recorder`) + + // "Click" the Flight Recorder link. + // It should redirect to the web server + // of a "go tool trace" process. + // The resulting web page is entirely programmatic, + // so we check for an arbitrary expected symbol. + doc2 := get(t, uri+"/flightrecorder") + checkMatch(t, true, doc2, `onTraceViewerImportFail`) + }) +} From c7873a32f081fb34978527f7df03166ca57a0024 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 3 Jun 2025 17:10:19 -0400 Subject: [PATCH 182/196] gopls/internal/golang: eliminate dot import: skip keyed fields This CL fixes a bug in "eliminate dot import" that would cause it to change field F in T{F: ...} to T{pkg.F: ...}. + test Fixes golang/go#73960 Change-Id: I090dacf855ec04137b5995d22a0b6fafc3f32d61 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678595 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- gopls/internal/golang/codeaction.go | 11 ++++++++--- .../testdata/codeaction/eliminate_dot_import.txt | 16 +++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 703b06bc6a2..7ed96cf4505 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -17,6 +17,7 @@ import ( "strings" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/edge" "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/analysis/fillstruct" "golang.org/x/tools/gopls/internal/analysis/fillswitch" @@ -733,11 +734,15 @@ func refactorRewriteEliminateDotImport(ctx context.Context, req *codeActionsRequ continue } - // Only qualify unqualified identifiers (due to dot imports). + // Only qualify unqualified identifiers (due to dot imports) + // that reference package-level symbols. 
// All other references to a symbol imported from another package // are nested within a select expression (pkg.Foo, v.Method, v.Field). - if is[*ast.SelectorExpr](curId.Parent().Node()) { - continue + if ek, _ := curId.ParentEdge(); ek == edge.SelectorExpr_Sel { + continue // qualified identifier (pkg.X) or selector (T.X or e.X) + } + if !typesinternal.IsPackageLevel(use) { + continue // unqualified field reference T{X: ...} } // Make sure that the package name will not be shadowed by something else in scope. diff --git a/gopls/internal/test/marker/testdata/codeaction/eliminate_dot_import.txt b/gopls/internal/test/marker/testdata/codeaction/eliminate_dot_import.txt index e72d8bd5417..f2e4d58732e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/eliminate_dot_import.txt +++ b/gopls/internal/test/marker/testdata/codeaction/eliminate_dot_import.txt @@ -1,7 +1,7 @@ This test checks the behavior of the 'remove dot import' code action. -- go.mod -- -module golang.org/lsptests/removedotimport +module example.com go 1.18 @@ -13,6 +13,7 @@ package dotimport import ( . "fmt" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a1) . "bytes" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a2) + . "time" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a3) ) var _ = a @@ -22,19 +23,28 @@ func a() { buf := NewBuffer(nil) buf.Grow(10) + + _ = Ticker{C: nil} } -- @a1/a.go -- @@ -6 +6 @@ - . "fmt" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a1) + "fmt" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a1) -@@ -13 +13 @@ +@@ -14 +14 @@ - Println("hello") + fmt.Println("hello") -- @a2/a.go -- @@ -7 +7 @@ - . "bytes" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a2) + "bytes" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a2) -@@ -15 +15 @@ +@@ -16 +16 @@ - buf := NewBuffer(nil) + buf := bytes.NewBuffer(nil) +-- @a3/a.go -- +@@ -8 +8 @@ +- . 
"time" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a3) ++ "time" //@codeaction(`.`, "refactor.rewrite.eliminateDotImport", edit=a3) +@@ -19 +19 @@ +- _ = Ticker{C: nil} ++ _ = time.Ticker{C: nil} From fed8cc83fd23c3333c5c249460ab8bc37ca9e6f1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 30 May 2025 08:46:58 -0400 Subject: [PATCH 183/196] internal/refactor: keep comments with same import When adding a new import spec, set its position to just after the comment on the last existing spec. This maintains comments on existing specs. Change-Id: I24e5d3b9277f5ab3cc585f27f2bf269fea5eb18e Reviewed-on: https://go-review.googlesource.com/c/tools/+/677518 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- internal/refactor/inline/inline.go | 16 +++++++++++++++- .../inline/testdata/import-comments.txtar | 4 ++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index 7e1c9994ddf..0697f3aafb8 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -330,20 +330,35 @@ func (st *state) inline() (*Result, error) { } } // Add new imports. + // Set their position to after the last position of the old imports, to keep + // comments on the old imports from moving. + lastPos := token.NoPos + if lastSpec := last(importDecl.Specs); lastSpec != nil { + lastPos = lastSpec.Pos() + if c := lastSpec.(*ast.ImportSpec).Comment; c != nil { + lastPos = c.Pos() + } + } for _, imp := range newImports { // Check that the new imports are accessible. 
    path, _ := strconv.Unquote(imp.spec.Path.Value) if !analysisinternal.CanImport(caller.Types.Path(), path) { return nil, fmt.Errorf("can't inline function %v as its body refers to inaccessible package %q", callee, path) } + if lastPos.IsValid() { + lastPos++ + imp.spec.Path.ValuePos = lastPos + } importDecl.Specs = append(importDecl.Specs, imp.spec) } + var out bytes.Buffer out.Write(before) commented := &printer.CommentedNode{ Node: importDecl, Comments: comments, } + if err := format.Node(&out, fset, commented); err != nil { logf("failed to format new importDecl: %v", err) // debugging return nil, err @@ -354,7 +369,6 @@ func (st *state) inline() (*Result, error) { return nil, err } } - // Delete imports referenced only by caller.Call.Fun. for _, oldImport := range res.oldImports { specToDelete := oldImport.spec diff --git a/internal/refactor/inline/testdata/import-comments.txtar b/internal/refactor/inline/testdata/import-comments.txtar index d4a4122c4d1..b5319e48846 100644 --- a/internal/refactor/inline/testdata/import-comments.txtar +++ b/internal/refactor/inline/testdata/import-comments.txtar @@ -28,7 +28,7 @@ import ( "io" // This is an import of c. - "testdata/c" + "testdata/c" // yes, of c ) var ( @@ -52,7 +52,7 @@ import ( // This is an import of c. "testdata/b" - "testdata/c" + "testdata/c" // yes, of c ) var ( From 61f37dc0fc255d3b05a1cfe785b9561f72962288 Mon Sep 17 00:00:00 2001 From: Peter Weinberger Date: Tue, 3 Jun 2025 13:27:33 -0400 Subject: [PATCH 184/196] gopls: use new gomodcache index This CL changes the default for 'importsSource' to 'gopls', so that imports and unimported completions will use the GOMODCACHE index. One test had to be changed, and there is a benchmark, with some data. (On unimported completions from the module cache, the new code is more than 10 times faster, and always succeeds where the old code sometimes failed.)
Change-Id: Ie5c8001ac1292498b72e5b42b51b4fcb06ab6fa9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678475 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/doc/release/v0.19.0.md | 12 ++ .../internal/golang/completion/unimported.go | 20 +-- gopls/internal/settings/default.go | 2 +- .../test/integration/bench/unimported_test.go | 161 ++++++++++++++++++ .../marker/testdata/completion/issue62676.txt | 2 +- 5 files changed, 179 insertions(+), 18 deletions(-) create mode 100644 gopls/internal/test/integration/bench/unimported_test.go diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index 05aeb2ec738..ec9d8362c53 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -173,3 +173,15 @@ func f(x int) { println(fmt.Sprintf("+%d", x)) } ``` + +## Use index for GOMODCACHE in imports and unimported completions + +The default for the option `importsSource` changes from "goimports" to "gopls". +This has the effect of building and maintaining an index to +the packages in GOMODCACHE. +The index is stored in the directory `os.UserCacheDir()/go/imports`. +Users who want the old behavior can change the option back. Users who don't want +the module cache used at all for imports or completions +can change the option to +"off". The new code is many times faster than the old when accessing the +module cache. \ No newline at end of file diff --git a/gopls/internal/golang/completion/unimported.go b/gopls/internal/golang/completion/unimported.go index e562f5cd83c..f4dbcde5c98 100644 --- a/gopls/internal/golang/completion/unimported.go +++ b/gopls/internal/golang/completion/unimported.go @@ -52,21 +52,6 @@ func (c *completer) unimported(ctx context.Context, pkgname metadata.PackageName } } // do the stdlib next. - // For now, use the workspace version of stdlib packages - // to get function snippets. CL 665335 will fix this.
- var x []metadata.PackageID - for _, mp := range stdpkgs { - if slices.Contains(wsIDs, metadata.PackageID(mp)) { - x = append(x, metadata.PackageID(mp)) - } - } - if len(x) > 0 { - items := c.pkgIDmatches(ctx, x, pkgname, prefix) - if c.scoreList(items) { - return - } - } - // just use the stdlib items := c.stdlibMatches(stdpkgs, pkgname, prefix) if c.scoreList(items) { return @@ -164,7 +149,7 @@ func (c *completer) pkgIDmatches(ctx context.Context, ids []metadata.PackageID, } kind = protocol.FunctionCompletion detail = fmt.Sprintf("func (from %q)", pkg.PkgPath) - case protocol.Variable: + case protocol.Variable, protocol.Struct: kind = protocol.VariableCompletion detail = fmt.Sprintf("var (from %q)", pkg.PkgPath) case protocol.Constant: @@ -264,6 +249,9 @@ func (c *completer) modcacheMatches(pkg metadata.PackageName, prefix string) ([] case modindex.Const: kind = protocol.ConstantCompletion detail = fmt.Sprintf("const (from %s)", cand.ImportPath) + case modindex.Type: // might be a type alias + kind = protocol.VariableCompletion + detail = fmt.Sprintf("type (from %s)", cand.ImportPath) default: continue } diff --git a/gopls/internal/settings/default.go b/gopls/internal/settings/default.go index 70adc1ade02..744e8d5d352 100644 --- a/gopls/internal/settings/default.go +++ b/gopls/internal/settings/default.go @@ -39,7 +39,7 @@ func DefaultOptions(overrides ...func(*Options)) *Options { DynamicWatchedFilesSupported: true, LineFoldingOnly: false, HierarchicalDocumentSymbolSupport: true, - ImportsSource: ImportsSourceGoimports, + ImportsSource: ImportsSourceGopls, }, ServerOptions: ServerOptions{ SupportedCodeActions: map[file.Kind]map[protocol.CodeActionKind]bool{ diff --git a/gopls/internal/test/integration/bench/unimported_test.go b/gopls/internal/test/integration/bench/unimported_test.go new file mode 100644 index 00000000000..9d7139b0bce --- /dev/null +++ b/gopls/internal/test/integration/bench/unimported_test.go @@ -0,0 +1,161 @@ +// Copyright 2025 The Go 
Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bench + +import ( + "context" + "fmt" + "go/token" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + . "golang.org/x/tools/gopls/internal/test/integration" + "golang.org/x/tools/gopls/internal/test/integration/fake" + "golang.org/x/tools/internal/modindex" +) + +// experiments show the new code about 15 times faster than the old, +// and the old code sometimes fails to find the completion +func BenchmarkLocalModcache(b *testing.B) { + budgets := []string{"0s", "100ms", "200ms", "500ms", "1s", "5s"} + sources := []string{"gopls", "goimports"} + for _, budget := range budgets { + b.Run(fmt.Sprintf("budget=%s", budget), func(b *testing.B) { + for _, source := range sources { + b.Run(fmt.Sprintf("source=%s", source), func(b *testing.B) { + runModcacheCompletion(b, budget, source) + }) + } + }) + } +} + +func runModcacheCompletion(b *testing.B, budget, source string) { + // First set up the program to be edited + gomod := ` +module mod.com + +go 1.21 +` + pat := ` +package main +var _ = %s.%s +` + pkg, name, modcache := findSym(b) + name, _, _ = strings.Cut(name, " ") + mainfile := fmt.Sprintf(pat, pkg, name) + // Second, create the Env and start gopls + dir := getTempDir() + if err := os.Mkdir(dir, 0750); err != nil { + if !os.IsExist(err) { + b.Fatal(err) + } + } + defer os.RemoveAll(dir) // is this right? needed? + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(gomod), 0644); err != nil { + b.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "main.go"), []byte(mainfile), 0644); err != nil { + b.Fatal(err) + } + ts, err := newGoplsConnector(nil) + if err != nil { + b.Fatal(err) + } + // PJW: put better EditorConfig here + envvars := map[string]string{ + "GOMODCACHE": modcache, + //"GOPATH": sandbox.GOPATH(), // do we need a GOPATH? 
+ } + fc := fake.EditorConfig{ + Env: envvars, + Settings: map[string]any{ + "completeUnimported": true, + "completionBudget": budget, // "0s", "100ms" + "importsSource": source, // "gopls" or "goimports" + }, + } + sandbox, editor, awaiter, err := connectEditor(dir, fc, ts) + if err != nil { + b.Fatal(err) + } + defer sandbox.Close() + defer editor.Close(context.Background()) + if err := awaiter.Await(context.Background(), InitialWorkspaceLoad); err != nil { + b.Fatal(err) + } + env := &Env{ + TB: b, + Ctx: context.Background(), + Editor: editor, + Sandbox: sandbox, + Awaiter: awaiter, + } + // Check that completion works as expected + env.CreateBuffer("main.go", mainfile) + env.AfterChange() + if false { // warm up? or not? + loc := env.RegexpSearch("main.go", name) + completions := env.Completion(loc) + if len(completions.Items) == 0 { + b.Fatal("no completions") + } + } + + // run benchmark + for b.Loop() { + loc := env.RegexpSearch("main.go", name) + env.Completion(loc) + } +} + +// find some symbol in the module cache +func findSym(t testing.TB) (pkg, name, gomodcache string) { + initForTest(t) + cmd := exec.Command("go", "env", "GOMODCACHE") + out, err := cmd.Output() + if err != nil { + t.Fatal(err) + } + modcache := strings.TrimSpace(string(out)) + ix, err := modindex.ReadIndex(modcache) + if err != nil { + t.Fatal(err) + } + if ix == nil { + t.Fatal("no index") + } + if len(ix.Entries) == 0 { + t.Fatal("no entries") + } + nth := 100 // or something + for _, e := range ix.Entries { + if token.IsExported(e.PkgName) || strings.HasPrefix(e.PkgName, "_") { + continue // weird stuff in module cache + } + + for _, nm := range e.Names { + nth-- + if nth == 0 { + return e.PkgName, nm, modcache + } + } + } + t.Fatalf("index doesn't have enough usable names, need another %d", nth) + return "", "", modcache +} + +// Set IndexDir, avoiding the special case for tests, +func initForTest(t testing.TB) { + dir, err := os.UserCacheDir() + if err != nil { + 
t.Fatalf("os.UserCacheDir: %v", err) + } + dir = filepath.Join(dir, "go", "imports") + modindex.IndexDir = dir +} diff --git a/gopls/internal/test/marker/testdata/completion/issue62676.txt b/gopls/internal/test/marker/testdata/completion/issue62676.txt index af4c3b695ec..6251e944aa1 100644 --- a/gopls/internal/test/marker/testdata/completion/issue62676.txt +++ b/gopls/internal/test/marker/testdata/completion/issue62676.txt @@ -53,7 +53,7 @@ import "os" func _() { // This uses goimports-based completion; TODO: this should insert snippets. - os.Open //@acceptcompletion(re"Open()", "Open", open) + os.Open(${1:}) //@acceptcompletion(re"Open()", "Open", open) } func _() { From e43ca0ca5dda3beb023842a433a909d4fcc26807 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 6 May 2025 06:55:07 -0400 Subject: [PATCH 185/196] internal/mcp: validate tool input schemas A tool now validates its input with its InputSchema. Change schema inference to allow explicit nulls for fields of pointer type. We assume the schema has no external references. For example, the following schema cannot be handled: { "$ref": "https://example.com/other.json" } Schemas with internal references, like to a "$defs", are fine. 
Change-Id: I6ee7c18c2c5cb609df0b22a66da986f7ea64bbe4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670676 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/jsonschema/infer.go | 12 ++++- internal/mcp/jsonschema/infer_test.go | 2 +- internal/mcp/jsonschema/resolve.go | 4 ++ internal/mcp/jsonschema/util.go | 2 +- internal/mcp/jsonschema/validate.go | 26 ++++++---- internal/mcp/mcp_test.go | 1 + internal/mcp/prompt.go | 6 ++- internal/mcp/tool.go | 49 +++++++++++++++-- internal/mcp/tool_test.go | 75 +++++++++++++++++++++++++++ 9 files changed, 160 insertions(+), 17 deletions(-) diff --git a/internal/mcp/jsonschema/infer.go b/internal/mcp/jsonschema/infer.go index 4ce270e5159..f044996c304 100644 --- a/internal/mcp/jsonschema/infer.go +++ b/internal/mcp/jsonschema/infer.go @@ -42,15 +42,21 @@ func For[T any]() (*Schema, error) { // - complex numbers // - unsafe pointers // +// There cannot be any cycles in the types. // TODO(rfindley): we could perhaps just skip these incompatible fields. func ForType(t reflect.Type) (*Schema, error) { return typeSchema(t) } func typeSchema(t reflect.Type) (*Schema, error) { - if t.Kind() == reflect.Pointer { + // Follow pointers: the schema for *T is almost the same as for T, except that + // an explicit JSON "null" is allowed for the pointer.
+ allowNull := false + for t.Kind() == reflect.Pointer { + allowNull = true t = t.Elem() } + var ( s = new(Schema) err error @@ -121,6 +127,10 @@ func typeSchema(t reflect.Type) (*Schema, error) { default: return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) } + if allowNull && s.Type != "" { + s.Types = []string{"null", s.Type} + s.Type = "" + } return s, nil } diff --git a/internal/mcp/jsonschema/infer_test.go b/internal/mcp/jsonschema/infer_test.go index 150824cb947..b695d216891 100644 --- a/internal/mcp/jsonschema/infer_test.go +++ b/internal/mcp/jsonschema/infer_test.go @@ -57,7 +57,7 @@ func TestForType(t *testing.T) { Properties: map[string]*schema{ "f": {Type: "integer"}, "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Type: "boolean"}, + "P": {Types: []string{"null", "boolean"}}, "NoSkip": {Type: "string"}, }, Required: []string{"f", "G", "P"}, diff --git a/internal/mcp/jsonschema/resolve.go b/internal/mcp/jsonschema/resolve.go index 1754135e6aa..0f913d82f51 100644 --- a/internal/mcp/jsonschema/resolve.go +++ b/internal/mcp/jsonschema/resolve.go @@ -27,6 +27,10 @@ type Resolved struct { resolvedURIs map[string]*Schema } +// Schema returns the schema that was resolved. +// It must not be modified. +func (r *Resolved) Schema() *Schema { return r.root } + // A Loader reads and unmarshals the schema at uri, if any. 
type Loader func(uri *url.URL) (*Schema, error) diff --git a/internal/mcp/jsonschema/util.go b/internal/mcp/jsonschema/util.go index 58c11ff1df7..550700290f0 100644 --- a/internal/mcp/jsonschema/util.go +++ b/internal/mcp/jsonschema/util.go @@ -270,7 +270,7 @@ func jsonType(v reflect.Value) (string, bool) { return "string", true case reflect.Slice, reflect.Array: return "array", true - case reflect.Map: + case reflect.Map, reflect.Struct: return "object", true default: return "", false diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 466e34506aa..9068ae6bb6a 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -662,8 +662,8 @@ func property(v reflect.Value, name string) reflect.Value { case reflect.Struct: props := structPropertiesOf(v.Type()) // Ignore nonexistent properties. - if index, ok := props[name]; ok { - return v.FieldByIndex(index) + if sf, ok := props[name]; ok { + return v.FieldByIndex(sf.Index) } return reflect.Value{} default: @@ -673,6 +673,8 @@ func property(v reflect.Value, name string) reflect.Value { // properties returns an iterator over the names and values of all properties // in v, which must be a map or a struct. +// If a struct, zero-valued properties that are marked omitempty or omitzero +// are excluded. 
func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { return func(yield func(string, reflect.Value) bool) { switch v.Kind() { @@ -683,8 +685,14 @@ func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { } } case reflect.Struct: - for name, index := range structPropertiesOf(v.Type()) { - if !yield(name, v.FieldByIndex(index)) { + for name, sf := range structPropertiesOf(v.Type()) { + val := v.FieldByIndex(sf.Index) + if val.IsZero() { + if tag, ok := sf.Tag.Lookup("json"); ok && (strings.Contains(tag, "omitempty") || strings.Contains(tag, "omitzero")) { + continue + } + } + if !yield(name, val) { return } } @@ -707,8 +715,8 @@ func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) case reflect.Struct: sp := structPropertiesOf(v.Type()) min := 0 - for prop, index := range sp { - if !v.FieldByIndex(index).IsZero() || isRequired[prop] { + for prop, sf := range sp { + if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { min++ } } @@ -719,7 +727,7 @@ func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) } // A propertyMap is a map from property name to struct field index. 
-type propertyMap = map[string][]int +type propertyMap = map[string]reflect.StructField var structProperties sync.Map // from reflect.Type to propertyMap @@ -730,10 +738,10 @@ func structPropertiesOf(t reflect.Type) propertyMap { if props, ok := structProperties.Load(t); ok { return props.(propertyMap) } - props := map[string][]int{} + props := map[string]reflect.StructField{} for _, sf := range reflect.VisibleFields(t) { if name, ok := jsonName(sf); ok { - props[name] = sf.Index + props[name] = sf } } structProperties.Store(t, props) diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 96a3dd7269c..2ab5b875507 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -709,4 +709,5 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { } } +// A function, because schemas must form a tree (they have hidden state). func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } diff --git a/internal/mcp/prompt.go b/internal/mcp/prompt.go index f57ccfb1069..97aed980259 100644 --- a/internal/mcp/prompt.go +++ b/internal/mcp/prompt.go @@ -42,6 +42,10 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, if schema.Type != "object" || !reflect.DeepEqual(schema.AdditionalProperties, &jsonschema.Schema{Not: &jsonschema.Schema{}}) { panic(fmt.Sprintf("handler request type must be a struct")) } + resolved, err := schema.Resolve(nil) + if err != nil { + panic(err) + } prompt := &ServerPrompt{ Prompt: &Prompt{ Name: name, @@ -70,7 +74,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context, return nil, err } var v TReq - if err := unmarshalSchema(rawArgs, schema, &v); err != nil { + if err := unmarshalSchema(rawArgs, resolved, &v); err != nil { return nil, err } return handler(ctx, ss, v, params) diff --git a/internal/mcp/tool.go b/internal/mcp/tool.go index 099321fde1e..6e8b3cf45e4 100644 --- a/internal/mcp/tool.go +++ 
b/internal/mcp/tool.go @@ -5,8 +5,10 @@ package mcp import ( + "bytes" "context" "encoding/json" + "fmt" "slices" "golang.org/x/tools/internal/mcp/jsonschema" @@ -39,10 +41,17 @@ func NewTool[TReq any](name, description string, handler ToolHandler[TReq], opts if err != nil { panic(err) } + // We must resolve the schema after the ToolOptions have had a chance to update it. + // But the handler needs access to the resolved schema, and the options may change + // the handler too. + // The best we can do is use the resolved schema in our own wrapped handler, + // and hope that no ToolOption replaces it. + // TODO(jba): at a minimum, document this. + var resolved *jsonschema.Resolved wrapped := func(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { var params2 CallToolParams[TReq] if params.Arguments != nil { - if err := unmarshalSchema(params.Arguments, schema, &params2.Arguments); err != nil { + if err := unmarshalSchema(params.Arguments, resolved, &params2.Arguments); err != nil { return nil, err } } @@ -68,15 +77,38 @@ func NewTool[TReq any](name, description string, handler ToolHandler[TReq], opts for _, opt := range opts { opt.set(t) } + if schema := t.Tool.InputSchema; schema != nil { + // Resolve the schema, with no base URI. We don't expect tool schemas to + // refer outside of themselves. + resolved, err = schema.Resolve(nil) + if err != nil { + panic(fmt.Errorf("resolving input schema %s: %w", schemaJSON(schema), err)) + } + } return t } // unmarshalSchema unmarshals data into v and validates the result according to -// the given schema. -func unmarshalSchema(data json.RawMessage, _ *jsonschema.Schema, v any) error { +// the given resolved schema. +func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) error { // TODO: use reflection to create the struct type to unmarshal into. // Separate validation from assignment. 
- return json.Unmarshal(data, v) + + // Disallow unknown fields.
+ // Otherwise, if the tool was built with a struct, the client could send extra + // fields and json.Unmarshal would ignore them, so the schema would never get + // a chance to declare the extra args invalid. + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + if err := dec.Decode(v); err != nil { + return fmt.Errorf("unmarshaling: %w", err) + } + if resolved != nil { + if err := resolved.Validate(v); err != nil { + return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err) + } + } + return nil } // A ToolOption configures the behavior of a Tool. @@ -177,3 +209,12 @@ func Schema(schema *jsonschema.Schema) SchemaOption { *s = *schema }) } + +// schemaJSON returns the JSON value for s as a string, or a string indicating an error. +func schemaJSON(s *jsonschema.Schema) string { + m, err := json.Marshal(s) + if err != nil { + return fmt.Sprintf("", err) + } + return string(m) +} diff --git a/internal/mcp/tool_test.go b/internal/mcp/tool_test.go index 646e9b32992..077ea392b82 100644 --- a/internal/mcp/tool_test.go +++ b/internal/mcp/tool_test.go @@ -6,6 +6,8 @@ package mcp_test import ( "context" + "encoding/json" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -88,3 +90,76 @@ func TestNewTool(t *testing.T) { } } } + +func TestNewToolValidate(t *testing.T) { + // Check that the tool returned from NewTool properly validates its input schema. 
+ + type req struct { + I int + B bool + S string `json:",omitempty"` + P *int `json:",omitempty"` + } + + dummyHandler := func(context.Context, *mcp.ServerSession, *mcp.CallToolParams[req]) (*mcp.CallToolResult, error) { + return nil, nil + } + + tool := mcp.NewTool("test", "test", dummyHandler) + for _, tt := range []struct { + desc string + args map[string]any + want string // error should contain this string; empty for success + }{ + { + "both required", + map[string]any{"I": 1, "B": true}, + "", + }, + { + "optional", + map[string]any{"I": 1, "B": true, "S": "foo"}, + "", + }, + { + "wrong type", + map[string]any{"I": 1.5, "B": true}, + "cannot unmarshal", + }, + { + "extra property", + map[string]any{"I": 1, "B": true, "C": 2}, + "unknown field", + }, + { + "value for pointer", + map[string]any{"I": 1, "B": true, "P": 3}, + "", + }, + { + "null for pointer", + map[string]any{"I": 1, "B": true, "P": nil}, + "", + }, + } { + t.Run(tt.desc, func(t *testing.T) { + raw, err := json.Marshal(tt.args) + if err != nil { + t.Fatal(err) + } + _, err = tool.Handler(context.Background(), nil, + &mcp.CallToolParams[json.RawMessage]{Arguments: json.RawMessage(raw)}) + if err == nil && tt.want != "" { + t.Error("got success, wanted failure") + } + if err != nil { + if tt.want == "" { + t.Fatalf("failed with:\n%s\nwanted success", err) + } + if !strings.Contains(err.Error(), tt.want) { + t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) + } + } + }) + } +} From cb39a5f00b69f514f37bc2a838174cc12c38c112 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 4 Jun 2025 10:56:09 -0400 Subject: [PATCH 186/196] gopls/internal/golang: Format generated files This CL reverts the effect of CL 365295 (to address issue golang/go#49555), which caused Format to fail on generated files. 
We believe this was a mistake: Go developers have plenty of reasons to temporarily edit generated files, for example to quickly experiment with changes without having to rewrite the code generator, or to add a log.Print statement while debugging. If a client asks gopls to format a file or organize its imports, gopls should do that, even if the file contains a "generated..." comment. + test Updates golang/go#49555 Fixes golang/go#73959 Change-Id: I2d02d5b2611b1599068df8fc267453773c66e027 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678815 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- gopls/internal/golang/format.go | 5 --- .../test/integration/misc/formatting_test.go | 33 ------------------- .../test/marker/testdata/format/generated.txt | 27 +++++++++++++++ 3 files changed, 27 insertions(+), 38 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/format/generated.txt diff --git a/gopls/internal/golang/format.go b/gopls/internal/golang/format.go index ef98580abff..fc3b2e35a68 100644 --- a/gopls/internal/golang/format.go +++ b/gopls/internal/golang/format.go @@ -40,11 +40,6 @@ func Format(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]pr return nil, err } - // Generated files shouldn't be edited. So, don't format them. - if ast.IsGenerated(pgf.File) { - return nil, fmt.Errorf("can't format %q: file is generated", fh.URI().Path()) - } - // Even if this file has parse errors, it might still be possible to format it. // Using format.Node on an AST with errors may result in code being modified. // Attempt to format the source of this file instead. 
diff --git a/gopls/internal/test/integration/misc/formatting_test.go b/gopls/internal/test/integration/misc/formatting_test.go index a0f86d3530c..190833c9868 100644 --- a/gopls/internal/test/integration/misc/formatting_test.go +++ b/gopls/internal/test/integration/misc/formatting_test.go @@ -268,39 +268,6 @@ func main() { } } -func TestFormattingOfGeneratedFile_Issue49555(t *testing.T) { - const input = ` --- main.go -- -// Code generated by generator.go. DO NOT EDIT. - -package main - -import "fmt" - -func main() { - - - - - fmt.Print("hello") -} -` - - Run(t, input, func(t *testing.T, env *Env) { - wantErrSuffix := "file is generated" - - env.OpenFile("main.go") - err := env.Editor.FormatBuffer(env.Ctx, "main.go") - if err == nil { - t.Fatal("expected error, got nil") - } - // Check only the suffix because an error contains a dynamic path to main.go - if !strings.HasSuffix(err.Error(), wantErrSuffix) { - t.Fatalf("unexpected error %q, want suffix %q", err.Error(), wantErrSuffix) - } - }) -} - func TestGofumptFormatting(t *testing.T) { // Exercise some gofumpt formatting rules: // - No empty lines following an assignment operator diff --git a/gopls/internal/test/marker/testdata/format/generated.txt b/gopls/internal/test/marker/testdata/format/generated.txt new file mode 100644 index 00000000000..5f7cd7fd0e3 --- /dev/null +++ b/gopls/internal/test/marker/testdata/format/generated.txt @@ -0,0 +1,27 @@ +This test checks that formatting includes generated files too +(reversing https://go.dev/cl/365295 to address issue #49555). + +See https://github.com/golang/go/issues/73959. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module example.com +go 1.21 + +-- a/a.go -- +// Code generated by me. DO NOT EDIT. + +package a; func main() { fmt.Println("hello") } + +//@format(out) + +-- @out -- +// Code generated by me. DO NOT EDIT. 
+ +package a + +func main() { fmt.Println("hello") } + +//@format(out) From 33d59880f345d37e4262f5f8e504ddfb6818266b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 4 Jun 2025 11:13:16 -0400 Subject: [PATCH 187/196] gopls/internal/server: Organize Imports of generated files This CL causes CodeAction to offer the Organize Imports operation even in generated files. Unlike most code actions with fixes, which are a nuisance in generated files since the fix should be applied to the generator logic not its output, Organize Imports is invaluable when making experimental temporary edits in generated files, such as adding logging statements. Also, it is silent on files that are well formed, which all unedited generated files must be. + test Updates golang/go#49555 Fixes golang/go#73959 Change-Id: Ia72f5b40402175ecd80dc62e7c34e8b3cf51d011 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678835 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/internal/server/code_action.go | 10 ++++++- .../testdata/codeaction/imports-generated.txt | 27 +++++++++++++++++++ .../test/marker/testdata/format/generated.txt | 4 ++- 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 gopls/internal/test/marker/testdata/codeaction/imports-generated.txt diff --git a/gopls/internal/server/code_action.go b/gopls/internal/server/code_action.go index e37cfc9f73e..e0fcd3eacbc 100644 --- a/gopls/internal/server/code_action.go +++ b/gopls/internal/server/code_action.go @@ -180,10 +180,16 @@ func (s *server) CodeAction(ctx context.Context, params *protocol.CodeActionPara } actions = append(actions, moreActions...) - // Don't suggest fixes for generated files, since they are generally + // Don't suggest most fixes for generated files, since they are generally // not useful and some editors may apply them automatically on save. // (Unfortunately there's no reliable way to distinguish fixes from // queries, so we must list all kinds of queries here.) 
+ // + // We make an exception for OrganizeImports, because + // (a) it is needed when making temporary experimental + // changes (e.g. adding logging) in generated files, and + // (b) it doesn't report diagnostics on well-formed code, and + // unedited generated files must be well formed. if golang.IsGenerated(ctx, snapshot, uri) { actions = slices.DeleteFunc(actions, func(a protocol.CodeAction) bool { switch a.Kind { @@ -194,6 +200,8 @@ func (s *server) CodeAction(ctx context.Context, params *protocol.CodeActionPara settings.GoplsDocFeatures, settings.GoToggleCompilerOptDetails: return false // read-only query + case settings.OrganizeImports: + return false // fix allowed in generated files (see #73959) } return true // potential write operation }) diff --git a/gopls/internal/test/marker/testdata/codeaction/imports-generated.txt b/gopls/internal/test/marker/testdata/codeaction/imports-generated.txt new file mode 100644 index 00000000000..879ea6cba33 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/imports-generated.txt @@ -0,0 +1,27 @@ +This test verifies that the 'source.organizeImports' code action +is offered in generated files (see #73959). + +-- go.mod -- +module example.com +go 1.21 + +-- a.go -- +// Code generated by me. DO NOT EDIT. + +package a //@codeaction("a", "source.organizeImports", result=out) + +func _() { + fmt.Println("hello") //@diag("fmt", re"undefined") +} + +-- @out/a.go -- +// Code generated by me. DO NOT EDIT. + +package a //@codeaction("a", "source.organizeImports", result=out) + +import "fmt" + +func _() { + fmt.Println("hello") //@diag("fmt", re"undefined") +} + diff --git a/gopls/internal/test/marker/testdata/format/generated.txt b/gopls/internal/test/marker/testdata/format/generated.txt index 5f7cd7fd0e3..3d571e00bee 100644 --- a/gopls/internal/test/marker/testdata/format/generated.txt +++ b/gopls/internal/test/marker/testdata/format/generated.txt @@ -13,7 +13,7 @@ go 1.21 -- a/a.go -- // Code generated by me. 
DO NOT EDIT. -package a; func main() { fmt.Println("hello") } +package a; import "fmt"; func main() { fmt.Println("hello") } //@format(out) @@ -22,6 +22,8 @@ package a; func main() { fmt.Println("hello") } package a +import "fmt" + func main() { fmt.Println("hello") } //@format(out) From 1afeefa8150f171e0a8f0948015513b31d59d2f3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 4 Jun 2025 11:10:58 -0400 Subject: [PATCH 188/196] internal/mcp: unexport FileResourceHandler Unexport our nice root-respecting local-filesystem resource handler. It's unclear how roots should interact with a resource handler that uses the local filesystem. All we know from the spec is that servers should "respect root boundaries during operations," but it's not clear how they can do that reliably if root changes are notifications, meaning there is no causality between a client changing a root and then asking the server for a file. (Maybe the server sees the new roots, maybe it doesn't.) The spec is loose about all this: it feels more like a sketch. No other SDK seems to devote much to providing a local helper, and those that do (Python) don't look at roots. Change-Id: Ia4d9abb9bc1a6450e1fc8e607df5d7bc6aa227e3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678855 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 0ecb3cd36bb..fa55c33b539 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -291,7 +291,7 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re return res, nil } -// FileResourceHandler returns a ReadResourceHandler that reads paths using dir as +// fileResourceHandler returns a ReadResourceHandler that reads paths using dir as // a base directory. // It honors client roots and protects against path traversal attacks. 
// @@ -303,7 +303,7 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re // Lexical path traversal attacks, where the path has ".." elements that escape dir, // are always caught. Go 1.24 and above also protects against symlink-based attacks, // where symlinks under dir lead out of the tree. -func (s *Server) FileResourceHandler(dir string) ResourceHandler { +func (s *Server) fileResourceHandler(dir string) ResourceHandler { return fileResourceHandler(dir) } From d9bacab54dfed6ac3f871f422bb0b2cb5eb5c428 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 4 Jun 2025 11:35:39 -0400 Subject: [PATCH 189/196] gopls/internal/server: improve "editing generated file" warning The phrasing should be more informative and less prescriptive. Also, a test. Updates golang/go#49555 Fixes golang/go#73959 Change-Id: Ie8206b39d1b03c323690095552400c75bfeb9c5a Reviewed-on: https://go-review.googlesource.com/c/tools/+/678836 Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- gopls/internal/server/text_synchronization.go | 12 ++++---- .../test/integration/misc/generate_test.go | 30 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/gopls/internal/server/text_synchronization.go b/gopls/internal/server/text_synchronization.go index 982d0e7a292..c11b4a22499 100644 --- a/gopls/internal/server/text_synchronization.go +++ b/gopls/internal/server/text_synchronization.go @@ -156,18 +156,20 @@ func (s *server) warnAboutModifyingGeneratedFiles(ctx context.Context, uri proto return nil } - // Ideally, we should be able to specify that a generated file should - // be opened as read-only. Tell the user that they should not be - // editing a generated file. 
+ // Warn the user that they are editing a generated file, but + // don't try to stop them: there are often good reasons to do + // so, such as adding temporary logging, or evaluating changes + // to the generated code without the trouble of modifying the + // generator logic (see #73959). snapshot, release, err := s.session.SnapshotOf(ctx, uri) if err != nil { return err } isGenerated := golang.IsGenerated(ctx, snapshot, uri) release() - if isGenerated { - msg := fmt.Sprintf("Do not edit this file! %s is a generated file.", uri.Path()) + msg := fmt.Sprintf("Warning: editing %s, a generated file.", + filepath.Base(uri.Path())) showMessage(ctx, s.client, protocol.Warning, msg) } return nil diff --git a/gopls/internal/test/integration/misc/generate_test.go b/gopls/internal/test/integration/misc/generate_test.go index 548f3bd5f5e..f5fe226c436 100644 --- a/gopls/internal/test/integration/misc/generate_test.go +++ b/gopls/internal/test/integration/misc/generate_test.go @@ -103,3 +103,33 @@ package main env.RunGenerate("./") }) } + +func TestEditingGeneratedFileWarning(t *testing.T) { + const src = ` +-- go.mod -- +module example.com +go 1.21 + +-- a/a.go -- +// Code generated by me. DO NOT EDIT. + +package a + +var x = 1 +` + Run(t, src, func(t *testing.T, env *Env) { + env.OpenFile("a/a.go") + env.RegexpReplace("a/a.go", "var", "const") + collectMessages := env.Awaiter.ListenToShownMessages() + env.Await(env.DoneWithChange()) + messages := collectMessages() + + const want = "Warning: editing a.go, a generated file." + if len(messages) != 1 || messages[0].Message != want { + for _, message := range messages { + t.Errorf("got message %q", message.Message) + } + t.Errorf("no %q warning", want) + } + }) +} From f3c581ff0cb8b4b87129f04094005c4b0f962bf9 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 4 Jun 2025 12:23:41 -0400 Subject: [PATCH 190/196] gopls/internal/protocol: add DocumentURI.Base accessor ...and use it everywhere. 
Change-Id: Ib2d59840a7dbbdfc48a62edbf7264ae3e9502378 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678758 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- gopls/internal/cache/snapshot.go | 2 +- gopls/internal/cache/workspace.go | 4 ++-- gopls/internal/cmd/cmd.go | 2 +- gopls/internal/golang/addtest.go | 2 +- gopls/internal/golang/call_hierarchy.go | 7 +++---- gopls/internal/golang/codeaction.go | 3 +-- gopls/internal/mcp/context.go | 9 ++++----- gopls/internal/protocol/uri.go | 5 +++++ gopls/internal/server/command.go | 2 +- gopls/internal/server/text_synchronization.go | 3 +-- gopls/internal/test/integration/misc/definition_test.go | 2 +- gopls/internal/test/integration/modfile/modfile_test.go | 3 +-- 12 files changed, 22 insertions(+), 22 deletions(-) diff --git a/gopls/internal/cache/snapshot.go b/gopls/internal/cache/snapshot.go index e78c1bba010..49707d0d6c7 100644 --- a/gopls/internal/cache/snapshot.go +++ b/gopls/internal/cache/snapshot.go @@ -1421,7 +1421,7 @@ https://github.com/golang/tools/blob/master/gopls/doc/workspace.md.`, modDir, fi fix = `This file may be excluded due to its build tags; try adding "-tags=" to your gopls "buildFlags" configuration See the documentation for more information on working with build tags: https://github.com/golang/tools/blob/master/gopls/doc/settings.md#buildflags.` - } else if strings.Contains(filepath.Base(fh.URI().Path()), "_") { + } else if strings.Contains(fh.URI().Base(), "_") { fix = `This file may be excluded due to its GOOS/GOARCH, or other build constraints.` } else { fix = `This file is ignored by your gopls build.` // we don't know why diff --git a/gopls/internal/cache/workspace.go b/gopls/internal/cache/workspace.go index 0621d17a537..6b2291e5bc9 100644 --- a/gopls/internal/cache/workspace.go +++ b/gopls/internal/cache/workspace.go @@ -18,7 +18,7 @@ import ( // isGoWork reports if uri is a go.work file. 
func isGoWork(uri protocol.DocumentURI) bool { - return filepath.Base(uri.Path()) == "go.work" + return uri.Base() == "go.work" } // goWorkModules returns the URIs of go.mod files named by the go.work file. @@ -63,7 +63,7 @@ func localModFiles(relativeTo string, goWorkOrModPaths []string) map[protocol.Do // isGoMod reports if uri is a go.mod file. func isGoMod(uri protocol.DocumentURI) bool { - return filepath.Base(uri.Path()) == "go.mod" + return uri.Base() == "go.mod" } // isWorkspaceFile reports if uri matches a set of globs defined in workspaceFiles diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index f4cfd99a6ba..d057698d594 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -903,7 +903,7 @@ func (f *cmdFile) spanRange(s span) (protocol.Range, error) { // case-sensitive directories. The authoritative answer // requires querying the file system, and we don't want // to do that. - if !strings.EqualFold(filepath.Base(string(f.mapper.URI)), filepath.Base(string(s.URI()))) { + if !strings.EqualFold(f.mapper.URI.Base(), s.URI().Base()) { return protocol.Range{}, bugpkg.Errorf("mapper is for file %q instead of %q", f.mapper.URI, s.URI()) } start, err := pointPosition(f.mapper, s.Start()) diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index 73665ce9755..dfd78310f66 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -265,7 +265,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. 
return nil, err } - testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go" + testBase := strings.TrimSuffix(loc.URI.Base(), ".go") + "_test.go" goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.DirPath(), testBase)) testFH, err := snapshot.ReadFile(ctx, goTestFileURI) diff --git a/gopls/internal/golang/call_hierarchy.go b/gopls/internal/golang/call_hierarchy.go index 1193d7e8de8..00bc02129f9 100644 --- a/gopls/internal/golang/call_hierarchy.go +++ b/gopls/internal/golang/call_hierarchy.go @@ -11,7 +11,6 @@ import ( "go/ast" "go/token" "go/types" - "path/filepath" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/types/typeutil" @@ -59,7 +58,7 @@ func PrepareCallHierarchy(ctx context.Context, snapshot *cache.Snapshot, fh file Name: obj.Name(), Kind: protocol.Function, Tags: []protocol.SymbolTag{}, - Detail: fmt.Sprintf("%s • %s", obj.Pkg().Path(), filepath.Base(declLoc.URI.Path())), + Detail: fmt.Sprintf("%s • %s", obj.Pkg().Path(), declLoc.URI.Base()), URI: declLoc.URI, Range: rng, SelectionRange: rng, @@ -182,7 +181,7 @@ func enclosingNodeCallItem(ctx context.Context, snapshot *cache.Snapshot, pkgPat Name: name, Kind: kind, Tags: []protocol.SymbolTag{}, - Detail: fmt.Sprintf("%s • %s", pkgPath, filepath.Base(fh.URI().Path())), + Detail: fmt.Sprintf("%s • %s", pkgPath, fh.URI().Base()), URI: loc.URI, Range: rng, SelectionRange: rng, @@ -283,7 +282,7 @@ func OutgoingCalls(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle Name: obj.Name(), Kind: protocol.Function, Tags: []protocol.SymbolTag{}, - Detail: fmt.Sprintf("%s • %s", obj.Pkg().Path(), filepath.Base(loc.URI.Path())), + Detail: fmt.Sprintf("%s • %s", obj.Pkg().Path(), loc.URI.Base()), URI: loc.URI, Range: loc.Range, SelectionRange: loc.Range, diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 7ed96cf4505..7a9212701ca 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ 
-11,7 +11,6 @@ import ( "go/ast" "go/token" "go/types" - "path/filepath" "reflect" "slices" "strings" @@ -1108,7 +1107,7 @@ func toggleCompilerOptDetails(ctx context.Context, req *codeActionsRequest) erro title := fmt.Sprintf("%s compiler optimization details for %q", cond(req.snapshot.WantCompilerOptDetails(dir), "Hide", "Show"), - filepath.Base(dir.Path())) + dir.Base()) cmd := command.NewGCDetailsCommand(title, req.fh.URI()) req.addCommandAction(cmd, false) } diff --git a/gopls/internal/mcp/context.go b/gopls/internal/mcp/context.go index 06911915c19..83c7633dc41 100644 --- a/gopls/internal/mcp/context.go +++ b/gopls/internal/mcp/context.go @@ -13,7 +13,6 @@ import ( "fmt" "go/ast" "go/token" - "path/filepath" "slices" "strings" @@ -58,7 +57,7 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal fmt.Fprintf(&result, "Current package %q (package %s) declares the following symbols:\n\n", pkg.Metadata().PkgPath, pkg.Metadata().Name) // Write context of the current file. { - fmt.Fprintf(&result, "%s (current file):\n", filepath.Base(pgf.URI.Path())) + fmt.Fprintf(&result, "%s (current file):\n", pgf.URI.Base()) result.WriteString("--->\n") if err := writeFileSummary(ctx, snapshot, pgf.URI, &result, false); err != nil { return nil, err @@ -73,7 +72,7 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal continue } - fmt.Fprintf(&result, "%s:\n", filepath.Base(file.URI.Path())) + fmt.Fprintf(&result, "%s:\n", file.URI.Base()) result.WriteString("--->\n") if err := writeFileSummary(ctx, snapshot, file.URI, &result, false); err != nil { return nil, err @@ -86,7 +85,7 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal if len(pgf.File.Imports) > 0 { // Write import decls of the current file. 
{ - fmt.Fprintf(&result, "Current file %q contains this import declaration:\n", filepath.Base(pgf.URI.Path())) + fmt.Fprintf(&result, "Current file %q contains this import declaration:\n", pgf.URI.Base()) result.WriteString("--->\n") // Add all import decl to output including all floating comment by // using GenDecl's start and end position. @@ -127,7 +126,7 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal fmt.Fprintf(&result, "%q (package %s)\n", importPath, impMetadata.Name) for _, f := range impMetadata.CompiledGoFiles { - fmt.Fprintf(&result, "%s:\n", filepath.Base(f.Path())) + fmt.Fprintf(&result, "%s:\n", f.Base()) result.WriteString("--->\n") if err := writeFileSummary(ctx, snapshot, f, &result, true); err != nil { return nil, err diff --git a/gopls/internal/protocol/uri.go b/gopls/internal/protocol/uri.go index 361bc441cfe..661521060f8 100644 --- a/gopls/internal/protocol/uri.go +++ b/gopls/internal/protocol/uri.go @@ -91,6 +91,11 @@ func (uri DocumentURI) Path() string { return filepath.FromSlash(filename) } +// Base returns the base name of the file path of the given URI. +func (uri DocumentURI) Base() string { + return filepath.Base(uri.Path()) +} + // Dir returns the URI for the directory containing the receiver. 
func (uri DocumentURI) Dir() DocumentURI { // This function could be more efficiently implemented by avoiding any call diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index 41de2cd7c7e..5c7427fa656 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -91,7 +91,7 @@ func (h *commandHandler) Modules(ctx context.Context, args command.ModulesArgs) return false // "can't happen" (see prior Encloses check) } - assert(filepath.Base(goMod.Path()) == "go.mod", fmt.Sprintf("invalid go.mod path: want go.mod, got %q", goMod.Path())) + assert(goMod.Base() == "go.mod", fmt.Sprintf("invalid go.mod path: want go.mod, got %q", goMod.Path())) // Invariant: rel is a relative path without "../" segments and the last // segment is "go.mod" diff --git a/gopls/internal/server/text_synchronization.go b/gopls/internal/server/text_synchronization.go index c11b4a22499..a993598b7ed 100644 --- a/gopls/internal/server/text_synchronization.go +++ b/gopls/internal/server/text_synchronization.go @@ -168,8 +168,7 @@ func (s *server) warnAboutModifyingGeneratedFiles(ctx context.Context, uri proto isGenerated := golang.IsGenerated(ctx, snapshot, uri) release() if isGenerated { - msg := fmt.Sprintf("Warning: editing %s, a generated file.", - filepath.Base(uri.Path())) + msg := fmt.Sprintf("Warning: editing %s, a generated file.", uri.Base()) showMessage(ctx, s.client, protocol.Warning, msg) } return nil diff --git a/gopls/internal/test/integration/misc/definition_test.go b/gopls/internal/test/integration/misc/definition_test.go index 8a9f27d20ac..d53e6585027 100644 --- a/gopls/internal/test/integration/misc/definition_test.go +++ b/gopls/internal/test/integration/misc/definition_test.go @@ -637,7 +637,7 @@ var _ = foo(123) // call env.OpenFile("a.go") locString := func(loc protocol.Location) string { - return fmt.Sprintf("%s:%s", filepath.Base(loc.URI.Path()), loc.Range) + return fmt.Sprintf("%s:%s", loc.URI.Base(), loc.Range) } // 
Definition at the call"foo(123)" takes us to the Go declaration. diff --git a/gopls/internal/test/integration/modfile/modfile_test.go b/gopls/internal/test/integration/modfile/modfile_test.go index 36ed9cf4138..c6639db6c27 100644 --- a/gopls/internal/test/integration/modfile/modfile_test.go +++ b/gopls/internal/test/integration/modfile/modfile_test.go @@ -6,7 +6,6 @@ package modfile import ( "os" - "path/filepath" "runtime" "strings" "testing" @@ -870,7 +869,7 @@ func hello() {} // Confirm that we still have metadata with only on-disk edits. env.OpenFile("main.go") loc := env.FirstDefinition(env.RegexpSearch("main.go", "hello")) - if filepath.Base(string(loc.URI)) != "hello.go" { + if loc.URI.Base() != "hello.go" { t.Fatalf("expected definition in hello.go, got %s", loc.URI) } // Confirm that we no longer have metadata when the file is saved. From 82473ce934847055bec96f8a96e4d1fc38ecefa9 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 4 Jun 2025 13:24:36 -0400 Subject: [PATCH 191/196] gopls/doc/release: tweak v0.19 Updates golang/go#73965 Change-Id: I532438e160a99bb1754d1de4ed9015eae3880fb9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678759 Reviewed-by: Robert Findley Auto-Submit: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/doc/release/v0.19.0.md | 179 ++++++++++++++++++++--------------- 1 file changed, 101 insertions(+), 78 deletions(-) diff --git a/gopls/doc/release/v0.19.0.md b/gopls/doc/release/v0.19.0.md index ec9d8362c53..a5e8956fb54 100644 --- a/gopls/doc/release/v0.19.0.md +++ b/gopls/doc/release/v0.19.0.md @@ -1,71 +1,15 @@ # Configuration Changes -- The `gopls check` subcommant now accepts a `-severity` flag to set a minimum +- The `gopls check` subcommand now accepts a `-severity` flag to set a minimum severity for the diagnostics it reports. By default, the minimum severity is "warning", so `gopls check` may report fewer diagnostics than before. Set `-severity=hint` to reproduce the previous behavior. 
-# New features +# Navigation features -## "Rename" of method receivers +## "Implementations" supports signature types (within same package) -The Rename operation, when applied to the declaration of a method -receiver, now also attempts to rename the receivers of all other -methods associated with the same named type. Each other receiver that -cannot be fully renamed is quietly skipped. - -Renaming a _use_ of a method receiver continues to affect only that -variable. - -```go -type Counter struct { x int } - - Rename here to affect only this method - ↓ -func (c *Counter) Inc() { c.x++ } -func (c *Counter) Dec() { c.x++ } - ↑ - Rename here to affect all methods -``` - -## Many `staticcheck` analyzers are enabled by default - -Slightly more than half of the analyzers in the -[Staticcheck](https://staticcheck.dev/docs/checks) suite are now -enabled by default. This subset has been chosen for precision and -efficiency. - -Previously, Staticcheck analyzers (all of them) would be run only if -the experimental `staticcheck` boolean option was set to `true`. This -value continues to enable the complete set, and a value of `false` -continues to disable the complete set. Leaving the option unspecified -enables the preferred subset of analyzers. - -Staticcheck analyzers, like all other analyzers, can be explicitly -enabled or disabled using the `analyzers` configuration setting; this -setting now takes precedence over the `staticcheck` setting, so, -regardless of what value of `staticcheck` you use (true/false/unset), -you can make adjustments to your preferred set of analyzers. - -## "Inefficient recursive iterator" analyzer - -A common pitfall when writing a function that returns an iterator -(iter.Seq) for a recursive data type is to recursively call the -function from its own implementation, leading to a stack of nested -coroutines, which is inefficient. 
- -The new `recursiveiter` analyzer detects such mistakes; see -[https://golang.org/x/tools/gopls/internal/analysis/recursiveiter](its -documentation) for details, including tips on how to define simple and -efficient recursive iterators. - -## "Inefficient range over maps.Keys/Values" analyzer - -This analyzer detects redundant calls to `maps.Keys` or `maps.Values` -as the operand of a range loop; maps can of course be ranged over -directly. - -## "Implementations" supports signature types + The Implementations query reports the correspondence between abstract and concrete types and their methods based on their method sets. @@ -89,9 +33,11 @@ Queries using method-sets should be invoked on the type or method name, and queries using signatures should be invoked on a `func` or `(` token. Only the local (same-package) algorithm is currently supported. -TODO: implement global. +(https://go.dev/issue/56572 tracks the global algorithm.) -## Go to Implementation +## "Go to Implementation" reports interface-to-interface relations + + The "Go to Implementation" operation now reports relationships between interfaces. Gopls now uses the concreteness of the query type to @@ -126,19 +72,102 @@ of the selected named type. -## "Eliminate dot import" code action -This code action, available on a dotted import, will offer to replace -the import with a regular one and qualify each use of the package -with its name. +# Editing features -### Auto-complete package clause for new Go files +## Completion: auto-complete package clause for new Go files Gopls now automatically adds the appropriate `package` clause to newly created Go files, so that you can immediately get started writing the interesting part. It requires client support for `workspace/didCreateFiles` +## New GOMODCACHE index for faster Organize Imports and unimported completions + +By default, gopls now builds and maintains a persistent index of +packages in the module cache (GOMODCACHE). 
The operations of Organize +Imports and completion of symbols from unimported packages are an +order of magnitude faster. + +To revert to the old behavior, set the `importsSource` option (whose +new default is `"gopls"`) to `"goimports"`. Users who don't want the +module cache used at all for imports or completions can change the +option to "off". + +# Analysis features + +## Most `staticcheck` analyzers are enabled by default + +Slightly more than half of the analyzers in the +[Staticcheck](https://staticcheck.dev/docs/checks) suite are now +enabled by default. This subset has been chosen for precision and +efficiency. + +Previously, Staticcheck analyzers (all of them) would be run only if +the experimental `staticcheck` boolean option was set to `true`. This +value continues to enable the complete set, and a value of `false` +continues to disable the complete set. Leaving the option unspecified +enables the preferred subset of analyzers. + +Staticcheck analyzers, like all other analyzers, can be explicitly +enabled or disabled using the `analyzers` configuration setting; this +setting now takes precedence over the `staticcheck` setting, so, +regardless of what value of `staticcheck` you use (true/false/unset), +you can make adjustments to your preferred set of analyzers. + +## `recursiveiter`: "inefficient recursive iterator" + +A common pitfall when writing a function that returns an iterator +(`iter.Seq`) for a recursive data type is to recursively call the +function from its own implementation, leading to a stack of nested +coroutines, which is inefficient. + +The new `recursiveiter` analyzer detects such mistakes; see +[its documentation](https://golang.org/x/tools/gopls/internal/analysis/recursiveiter) +for details, including tips on how to define simple and efficient +recursive iterators.
+ +## `maprange`: "inefficient range over maps.Keys/Values" + +The new `maprange` analyzer detects redundant calls to `maps.Keys` or +`maps.Values` as the operand of a range loop; maps can of course be +ranged over directly. See +[its documentation](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/maprange) +for details. + +# Code transformation features + +## Rename method receivers + + + +The Rename operation, when applied to the declaration of a method +receiver, now also attempts to rename the receivers of all other +methods associated with the same named type. Each other receiver that +cannot be fully renamed is quietly skipped. + +Renaming a _use_ of a method receiver continues to affect only that +variable. + +```go +type Counter struct { x int } + + Rename here to affect only this method + ↓ +func (c *Counter) Inc() { c.x++ } +func (c *Counter) Dec() { c.x-- } + ↑ + Rename here to affect all methods +``` + +## "Eliminate dot import" code action + + + +This code action, available on a dotted import, will offer to replace +the import with a regular one and qualify each use of the package +with its name. + ## Add/remove tags from struct fields Gopls now provides two new code actions, available on an entire struct @@ -156,6 +185,8 @@ type Info struct { ## Inline local variable + + The new `refactor.inline.variable` code action replaces a reference to a local variable by that variable's initializer expression. For example, when applied to `s` in `println(s)`: @@ -174,14 +205,6 @@ func f(x int) { } ``` -## Use index for GOMODCACHE in imports and unimported completions - -The default for the option `importsSource` changes from "goimports" to "gopls". -This has the effect of building and maintaining an index to -the packages in GOMODCACHE. -The index is stored in the directory `os.UserCacheDir()/go/imports`. -Users who want the old behavior can change the option back.
Users who don't -the module cache used at all for imports or completions -can change the option to -"off". The new code is many times faster than the old when accessing the -module cache. \ No newline at end of file +Only a single reference is replaced; issue https://go.dev/issue/70085 +tracks the feature to "inline all" uses of the variable and eliminate +it. From 4546fbd0b20190ede82382b293ae4440923ecaea Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 4 Jun 2025 09:36:17 -0400 Subject: [PATCH 192/196] internal/mcp: unify json tag parsing Do a thorough job of parsing json struct tags. Move the code to a place where both mcp and jsonschema can use it. Change-Id: I6c86308834706b12a1e22cd96ae120590f49d3f4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678757 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley --- internal/mcp/internal/util/util.go | 39 ++++++++++++++++++++++++ internal/mcp/internal/util/util_test.go | 38 +++++++++++++++++++++++ internal/mcp/jsonschema/infer.go | 34 +++++---------------- internal/mcp/jsonschema/schema.go | 7 +++-- internal/mcp/jsonschema/validate.go | 30 +++++------------- internal/mcp/jsonschema/validate_test.go | 20 ------------ internal/mcp/util.go | 30 +++--------------- 7 files changed, 101 insertions(+), 97 deletions(-) create mode 100644 internal/mcp/internal/util/util_test.go diff --git a/internal/mcp/internal/util/util.go b/internal/mcp/internal/util/util.go index 8ef5d1fd464..58e13efea4f 100644 --- a/internal/mcp/internal/util/util.go +++ b/internal/mcp/internal/util/util.go @@ -7,7 +7,9 @@ package util import ( "cmp" "iter" + "reflect" "slices" + "strings" ) // Helpers below are copied from gopls' moremaps package. @@ -34,3 +36,40 @@ func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { } return r } + +type JSONInfo struct { + Omit bool // unexported or first tag element is "-" + Name string // Go field name or first tag element. Empty if Omit is true. 
+ Settings map[string]bool // "omitempty", "omitzero", etc. +} + +// FieldJSONInfo reports information about how encoding/json +// handles the given struct field. +// If the field is unexported, JSONInfo.Omit is true and no other JSONInfo field +// is populated. +// If the field is exported and has no tag, then Name is the field's name and all +// other fields are false. +// Otherwise, the information is obtained from the tag. +func FieldJSONInfo(f reflect.StructField) JSONInfo { + if !f.IsExported() { + return JSONInfo{Omit: true} + } + info := JSONInfo{Name: f.Name} + if tag, ok := f.Tag.Lookup("json"); ok { + name, rest, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return JSONInfo{Omit: true} + } + if name != "" { + info.Name = name + } + if len(rest) > 0 { + info.Settings = map[string]bool{} + for _, s := range strings.Split(rest, ",") { + info.Settings[s] = true + } + } + } + return info +} diff --git a/internal/mcp/internal/util/util_test.go b/internal/mcp/internal/util/util_test.go new file mode 100644 index 00000000000..0ec133a459c --- /dev/null +++ b/internal/mcp/internal/util/util_test.go @@ -0,0 +1,38 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package util + +import ( + "reflect" + "testing" +) + +func TestJSONInfo(t *testing.T) { + type S struct { + A int + B int `json:","` + C int `json:"-"` + D int `json:"-,"` + E int `json:"echo"` + F int `json:"foxtrot,omitempty"` + g int `json:"golf"` + } + want := []JSONInfo{ + {Name: "A"}, + {Name: "B"}, + {Omit: true}, + {Name: "-"}, + {Name: "echo"}, + {Name: "foxtrot", Settings: map[string]bool{"omitempty": true}}, + {Omit: true}, + } + tt := reflect.TypeFor[S]() + for i := range tt.NumField() { + got := FieldJSONInfo(tt.Field(i)) + if !reflect.DeepEqual(got, want[i]) { + t.Errorf("got %+v, want %+v", got, want[i]) + } + } +} diff --git a/internal/mcp/jsonschema/infer.go b/internal/mcp/jsonschema/infer.go index f044996c304..feeaf6120bc 100644 --- a/internal/mcp/jsonschema/infer.go +++ b/internal/mcp/jsonschema/infer.go @@ -9,8 +9,8 @@ package jsonschema import ( "fmt" "reflect" - "slices" - "strings" + + "golang.org/x/tools/internal/mcp/internal/util" ) // For constructs a JSON schema object for the given type argument. 
@@ -108,19 +108,19 @@ func typeSchema(t reflect.Type) (*Schema, error) { for i := range t.NumField() { field := t.Field(i) - name, required, include := parseField(field) - if !include { + info := util.FieldJSONInfo(field) + if info.Omit { continue } if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[name], err = typeSchema(field.Type) + s.Properties[info.Name], err = typeSchema(field.Type) if err != nil { return nil, err } - if required { - s.Required = append(s.Required, name) + if !info.Settings["omitempty"] && !info.Settings["omitzero"] { + s.Required = append(s.Required, info.Name) } } @@ -133,23 +133,3 @@ func typeSchema(t reflect.Type) (*Schema, error) { } return s, nil } - -func parseField(f reflect.StructField) (name string, required, include bool) { - if !f.IsExported() { - return "", false, false - } - name = f.Name - required = true - if tag, ok := f.Tag.Lookup("json"); ok { - props := strings.Split(tag, ",") - if props[0] != "" { - if props[0] == "-" { - return "", false, false - } - name = props[0] - } - // TODO: support 'omitzero' as well. - required = !slices.Contains(props[1:], "omitempty") - } - return name, required, true -} diff --git a/internal/mcp/jsonschema/schema.go b/internal/mcp/jsonschema/schema.go index 225beaec958..f0eeb17096c 100644 --- a/internal/mcp/jsonschema/schema.go +++ b/internal/mcp/jsonschema/schema.go @@ -17,6 +17,8 @@ import ( "reflect" "regexp" "slices" + + "golang.org/x/tools/internal/mcp/internal/util" ) // A Schema is a JSON schema object. 
@@ -426,8 +428,9 @@ var ( func init() { for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { - if name, ok := jsonName(sf); ok { - schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, name}) + info := util.FieldJSONInfo(sf) + if !info.Omit { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.Name}) } } slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { diff --git a/internal/mcp/jsonschema/validate.go b/internal/mcp/jsonschema/validate.go index 9068ae6bb6a..b5f1757aa9e 100644 --- a/internal/mcp/jsonschema/validate.go +++ b/internal/mcp/jsonschema/validate.go @@ -16,6 +16,8 @@ import ( "strings" "sync" "unicode/utf8" + + "golang.org/x/tools/internal/mcp/internal/util" ) // The value of the "$schema" keyword for the version that we can validate. @@ -688,7 +690,8 @@ func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { for name, sf := range structPropertiesOf(v.Type()) { val := v.FieldByIndex(sf.Index) if val.IsZero() { - if tag, ok := sf.Tag.Lookup("json"); ok && (strings.Contains(tag, "omitempty") || strings.Contains(tag, "omitzero")) { + info := util.FieldJSONInfo(sf) + if info.Settings["omitempty"] || info.Settings["omitzero"] { continue } } @@ -740,30 +743,11 @@ func structPropertiesOf(t reflect.Type) propertyMap { } props := map[string]reflect.StructField{} for _, sf := range reflect.VisibleFields(t) { - if name, ok := jsonName(sf); ok { - props[name] = sf + info := util.FieldJSONInfo(sf) + if !info.Omit { + props[info.Name] = sf } } structProperties.Store(t, props) return props } - -// jsonName returns the name for f as would be used by [json.Marshal]. -// That is the name in the json struct tag, or the field name if there is no tag. -// If f is not exported or the tag is "-", jsonName returns "", false. 
-func jsonName(f reflect.StructField) (string, bool) { - if !f.IsExported() { - return "", false - } - if tag, ok := f.Tag.Lookup("json"); ok { - name, _, found := strings.Cut(tag, ",") - // "-" means omit, but "-," means the name is "-" - if name == "-" && !found { - return "", false - } - if name != "" { - return name, true - } - } - return f.Name, true -} diff --git a/internal/mcp/jsonschema/validate_test.go b/internal/mcp/jsonschema/validate_test.go index 76eb7c27c10..fe55cedb9ba 100644 --- a/internal/mcp/jsonschema/validate_test.go +++ b/internal/mcp/jsonschema/validate_test.go @@ -259,26 +259,6 @@ func TestStructInstance(t *testing.T) { } } -func TestJSONName(t *testing.T) { - type S struct { - A int - B int `json:","` - C int `json:"-"` - D int `json:"-,"` - E int `json:"echo"` - F int `json:"foxtrot,omitempty"` - g int `json:"golf"` - } - want := []string{"A", "B", "", "-", "echo", "foxtrot", ""} - tt := reflect.TypeFor[S]() - for i := range tt.NumField() { - got, _ := jsonName(tt.Field(i)) - if got != want[i] { - t.Errorf("got %q, want %q", got, want[i]) - } - } -} - func mustMarshal(x any) json.RawMessage { data, err := json.Marshal(x) if err != nil { diff --git a/internal/mcp/util.go b/internal/mcp/util.go index 3b1f53b124e..dae4d920ac8 100644 --- a/internal/mcp/util.go +++ b/internal/mcp/util.go @@ -9,8 +9,9 @@ import ( "encoding/json" "fmt" "reflect" - "strings" "sync" + + "golang.org/x/tools/internal/mcp/internal/util" ) func assert(cond bool, msg string) { @@ -135,32 +136,11 @@ func jsonNames(t reflect.Type) map[string]bool { } m := map[string]bool{} for i := range t.NumField() { - if n, ok := jsonName(t.Field(i)); ok { - m[n] = true + info := util.FieldJSONInfo(t.Field(i)) + if !info.Omit { + m[info.Name] = true } } jsonNamesMap.Store(t, m) return m } - -// jsonName returns the name for f as would be used by [json.Marshal]. -// That is the name in the json struct tag, or the field name if there is no tag. 
-// If f is not exported or the tag is "-", jsonName returns "", false. -// -// Copied from jsonschema/validate.go. -func jsonName(f reflect.StructField) (string, bool) { - if !f.IsExported() { - return "", false - } - if tag, ok := f.Tag.Lookup("json"); ok { - name, _, found := strings.Cut(tag, ",") - // "-" means omit, but "-," means the name is "-" - if name == "-" && !found { - return "", false - } - if name != "" { - return name, true - } - } - return f.Name, true -} From 64bfecc32e163d2684a85b73472919e02da50180 Mon Sep 17 00:00:00 2001 From: Madeline Kalil Date: Wed, 4 Jun 2025 14:12:56 -0400 Subject: [PATCH 193/196] gopls/internal/golang: fix extract bug with anon functions When the extracted block contains return statements but doesn't have any non-nested returns, we adjust all return statements in the extracted function. However, we need to avoid modifying the return statements of anonymous functions since this would change function behavior or cause a type error. Fixes golang/go#73972 Change-Id: I6e5c145c844d29c9ceae32da341b99a1ff1b2d54 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678915 LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/golang/extract.go | 4 ++ .../functionextraction_issue73972.txt | 52 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 gopls/internal/test/marker/testdata/codeaction/functionextraction_issue73972.txt diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 91bea65a1f2..4dc784045c1 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -1896,6 +1896,10 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object] if n == nil { return false } + // Don't modify return statements inside anonymous functions. 
+ if _, ok := n.(*ast.FuncLit); ok { + return false + } if n, ok := n.(*ast.ReturnStmt); ok { n.Results = slices.Concat(zeroVals, n.Results) return false diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue73972.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue73972.txt new file mode 100644 index 00000000000..bad151dd419 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue73972.txt @@ -0,0 +1,52 @@ +This test verifies the fix for golang/go#73972: extraction should +not modify the return statements of anonymous functions. + +-- go.mod -- +module mod.test/extract + +go 1.18 + +-- a.go -- +package extract + +import ( + "fmt" + "strings" +) + +func main() { + b := strings.ContainsFunc("a", func(_ rune) bool { //@codeaction("b", "refactor.extract.function", end=end, result=ext) + return false + }) + if b { + return + } //@loc(end, "}") + fmt.Println(b) +} + +-- @ext/a.go -- +package extract + +import ( + "fmt" + "strings" +) + +func main() { + b, shouldReturn := newFunction() + if shouldReturn { + return + } //@loc(end, "}") + fmt.Println(b) +} + +func newFunction() (bool, bool) { + b := strings.ContainsFunc("a", func(_ rune) bool { //@codeaction("b", "refactor.extract.function", end=end, result=ext) + return false + }) + if b { + return false, true + } + return b, false +} + From 82ee0fd1228b85b95daadd1901e83a9200d661e6 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Mon, 2 Jun 2025 16:37:01 +0000 Subject: [PATCH 194/196] internal/mcp: change paginateList to a generic helper This CL simplifies the paginateList function in server.go to use a generic helper for tools, resources, and prompts. 
Change-Id: Ide0d2a90d715374280067e094d8870882e6bddfa Reviewed-on: https://go-review.googlesource.com/c/tools/+/678055 Auto-Submit: Sam Thanawalla Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI --- internal/mcp/client.go | 16 ++----- internal/mcp/server.go | 94 ++++++++++++++++--------------------- internal/mcp/server_test.go | 57 ++++++++++++++++------ internal/mcp/shared.go | 10 ++++ 4 files changed, 96 insertions(+), 81 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 8f47f6fd420..668ed70b3fc 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -391,23 +391,13 @@ func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) }) } -type ListParams interface { - // Returns a pointer to the param's Cursor field. - cursorPtr() *string -} - -type ListResult[T any] interface { - // Returns a pointer to the param's NextCursor field. - nextCursorPtr() *string -} - // paginate is a generic helper function to provide a paginated iterator. 
-func paginate[P ListParams, R ListResult[E], E any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*E) iter.Seq2[E, error] { - return func(yield func(E, error) bool) { +func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { for { res, err := listFunc(ctx, params) if err != nil { - var zero E + var zero T yield(zero, err) return } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index fa55c33b539..02ed31109ec 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -182,21 +182,15 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() - var cursor string - if params != nil { - cursor = params.Cursor + if params == nil { + params = &ListPromptsParams{} } - prompts, nextCursor, err := paginateList(s.prompts, cursor, s.opts.PageSize) - if err != nil { - return nil, err - } - res := new(ListPromptsResult) - res.NextCursor = nextCursor - res.Prompts = []*Prompt{} // avoid JSON null - for _, p := range prompts { - res.Prompts = append(res.Prompts, p.Prompt) - } - return res, nil + return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*ServerPrompt) { + res.Prompts = []*Prompt{} // avoid JSON null + for _, p := range prompts { + res.Prompts = append(res.Prompts, p.Prompt) + } + }) } func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { @@ -213,21 +207,15 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { s.mu.Lock() 
defer s.mu.Unlock() - var cursor string - if params != nil { - cursor = params.Cursor - } - tools, nextCursor, err := paginateList(s.tools, cursor, s.opts.PageSize) - if err != nil { - return nil, err - } - res := new(ListToolsResult) - res.NextCursor = nextCursor - res.Tools = []*Tool{} // avoid JSON null - for _, t := range tools { - res.Tools = append(res.Tools, t.Tool) + if params == nil { + params = &ListToolsParams{} } - return res, nil + return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) { + res.Tools = []*Tool{} // avoid JSON null + for _, t := range tools { + res.Tools = append(res.Tools, t.Tool) + } + }) } func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) { @@ -243,21 +231,15 @@ func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallTo func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() - var cursor string - if params != nil { - cursor = params.Cursor - } - resources, nextCursor, err := paginateList(s.resources, cursor, s.opts.PageSize) - if err != nil { - return nil, err + if params == nil { + params = &ListResourcesParams{} } - res := new(ListResourcesResult) - res.NextCursor = nextCursor - res.Resources = []*Resource{} // avoid JSON null - for _, r := range resources { - res.Resources = append(res.Resources, r.Resource) - } - return res, nil + return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*ServerResource) { + res.Resources = []*Resource{} // avoid JSON null + for _, r := range resources { + res.Resources = append(res.Resources, r.Resource) + } + }) } func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { @@ -618,22 +600,25 
@@ func decodeCursor(cursor string) (*pageToken, error) { return &token, nil } -// paginateList returns a slice of features from the given featureSet, based on -// the provided cursor and page size. It also returns a new cursor for the next -// page, or an empty string if there are no more pages. -func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) { +// paginateList is a generic helper that returns a paginated slice of items +// from a featureSet. It populates the provided result res with the items +// and sets its next cursor for subsequent pages. +// If there are no more pages, the next cursor within the result will be an empty string. +func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageSize int, params P, res R, setFunc func(R, []T)) (R, error) { var seq iter.Seq[T] - if cursor == "" { + if params.cursorPtr() == nil || *params.cursorPtr() == "" { seq = fs.all() } else { - pageToken, err := decodeCursor(cursor) + pageToken, err := decodeCursor(*params.cursorPtr()) // According to the spec, invalid cursors should return Invalid params. if err != nil { - return nil, "", jsonrpc2.ErrInvalidParams + var zero R + return zero, jsonrpc2.ErrInvalidParams } seq = fs.above(pageToken.LastUID) } var count int + var features []T for f := range seq { count++ // If we've seen pageSize + 1 elements, we've gathered enough info to determine @@ -643,13 +628,16 @@ func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (featur } features = append(features, f) } + setFunc(res, features) // No remaining pages. 
if count < pageSize+1 { - return features, "", nil + return res, nil } - nextCursor, err = encodeCursor(fs.uniqueID(features[len(features)-1])) + nextCursor, err := encodeCursor(fs.uniqueID(features[len(features)-1])) if err != nil { - return nil, "", err + var zero R + return zero, err } - return features, nextCursor, nil + *res.nextCursorPtr() = nextCursor + return res, nil } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 16bf8a5317e..650209f9043 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -12,12 +12,29 @@ import ( "github.com/google/go-cmp/cmp" ) -type TestItem struct { +type testItem struct { Name string Value string } -var allTestItems = []*TestItem{ +type testListParams struct { + Cursor string +} + +func (p *testListParams) cursorPtr() *string { + return &p.Cursor +} + +type testListResult struct { + Items []*testItem + NextCursor string +} + +func (r *testListResult) nextCursorPtr() *string { + return &r.NextCursor +} + +var allTestItems = []*testItem{ {"alpha", "val-A"}, {"bravo", "val-B"}, {"charlie", "val-C"}, @@ -44,10 +61,10 @@ func getCursor(input string) string { func TestServerPaginateBasic(t *testing.T) { testCases := []struct { name string - initialItems []*TestItem + initialItems []*testItem inputCursor string inputPageSize int - wantFeatures []*TestItem + wantFeatures []*testItem wantNextCursor string wantErr bool }{ @@ -154,41 +171,51 @@ func TestServerPaginateBasic(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - fs := newFeatureSet(func(t *TestItem) string { return t.Name }) + fs := newFeatureSet(func(t *testItem) string { return t.Name }) fs.add(tc.initialItems...) 
- gotFeatures, gotNextCursor, err := paginateList(fs, tc.inputCursor, tc.inputPageSize) + params := &testListParams{Cursor: tc.inputCursor} + gotResult, err := paginateList(fs, tc.inputPageSize, params, &testListResult{}, func(res *testListResult, items []*testItem) { + res.Items = items + }) if (err != nil) != tc.wantErr { t.Errorf("paginateList(%s) error, got %v, wantErr %v", tc.name, err, tc.wantErr) } - if diff := cmp.Diff(tc.wantFeatures, gotFeatures); diff != "" { + if tc.wantErr { + return + } + if diff := cmp.Diff(tc.wantFeatures, gotResult.Items); diff != "" { t.Errorf("paginateList(%s) mismatch (-want +got):\n%s", tc.name, diff) } - if tc.wantNextCursor != gotNextCursor { - t.Errorf("paginateList(%s) nextCursor, got %v, want %v", tc.name, gotNextCursor, tc.wantNextCursor) + if tc.wantNextCursor != gotResult.NextCursor { + t.Errorf("paginateList(%s) nextCursor, got %v, want %v", tc.name, gotResult.NextCursor, tc.wantNextCursor) } }) } } func TestServerPaginateVariousPageSizes(t *testing.T) { - fs := newFeatureSet(func(t *TestItem) string { return t.Name }) + fs := newFeatureSet(func(t *testItem) string { return t.Name }) fs.add(allTestItems...) // Try all possible page sizes, ensuring we get the correct list of items. for pageSize := 1; pageSize < len(allTestItems)+1; pageSize++ { - var gotItems []*TestItem + var gotItems []*testItem var nextCursor string wantChunks := slices.Collect(slices.Chunk(allTestItems, pageSize)) index := 0 // Iterate through all pages, comparing sub-slices to the paginated list. 
for { - gotFeatures, gotNextCursor, err := paginateList(fs, nextCursor, pageSize) + params := &testListParams{Cursor: nextCursor} + gotResult, err := paginateList(fs, pageSize, params, &testListResult{}, func(res *testListResult, items []*testItem) { + res.Items = items + }) if err != nil { + t.Fatalf("paginateList() unexpected error for pageSize %d, cursor %q: %v", pageSize, nextCursor, err) } - if diff := cmp.Diff(wantChunks[index], gotFeatures); diff != "" { + if diff := cmp.Diff(wantChunks[index], gotResult.Items); diff != "" { t.Errorf("paginateList mismatch (-want +got):\n%s", diff) } - gotItems = append(gotItems, gotFeatures...) - nextCursor = gotNextCursor + gotItems = append(gotItems, gotResult.Items...) + nextCursor = gotResult.NextCursor if nextCursor == "" { break } diff --git a/internal/mcp/shared.go b/internal/mcp/shared.go index 8f0ce48e647..4670a360cbb 100644 --- a/internal/mcp/shared.go +++ b/internal/mcp/shared.go @@ -271,3 +271,13 @@ type Result interface { type emptyResult struct{} func (*emptyResult) GetMeta() *Meta { panic("should never be called") } + +type listParams interface { + // Returns a pointer to the param's Cursor field. + cursorPtr() *string +} + +type listResult[T any] interface { + // Returns a pointer to the param's NextCursor field. + nextCursorPtr() *string +} From f114dcf97d4f35feb86030bb9e1c5c8fc6fd8942 Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Wed, 4 Jun 2025 19:02:34 +0800 Subject: [PATCH 195/196] gopls/internal/protocol: refine DocumentURI Clean method and its usages This CL tracks the comments left by Alan Donovan in CL663295, thanks again. 
* change protocol.Clean function to DocumentURI method * ensure the same key when access and operate viewMap * factor some logics Change-Id: Ib717705495997379ed0f872d438f238ca1308a23 Reviewed-on: https://go-review.googlesource.com/c/tools/+/678416 Auto-Submit: Alan Donovan Reviewed-by: Robert Findley LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/cache/session.go | 16 ++++++++------- gopls/internal/protocol/uri.go | 2 +- .../internal/test/integration/fake/editor.go | 20 ++++++++----------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/gopls/internal/cache/session.go b/gopls/internal/cache/session.go index 8a9a589b708..166c30eef16 100644 --- a/gopls/internal/cache/session.go +++ b/gopls/internal/cache/session.go @@ -64,7 +64,7 @@ type Session struct { viewMu sync.Mutex views []*View - viewMap map[protocol.DocumentURI]*View // file->best view or nil; nil after shutdown + viewMap map[protocol.DocumentURI]*View // file->best view or nil; nil after shutdown; the key must be a clean uri. // snapshots is a counting semaphore that records the number // of unreleased snapshots associated with this session. @@ -139,15 +139,16 @@ func (s *Session) NewView(ctx context.Context, folder *Folder) (*View, *Snapshot } view, snapshot, release := s.createView(ctx, def) s.views = append(s.views, view) - s.viewMap[protocol.Clean(folder.Dir)] = view + s.viewMap[folder.Dir.Clean()] = view return view, snapshot, release, nil } // HasView checks whether the uri's view exists. func (s *Session) HasView(uri protocol.DocumentURI) bool { + uri = uri.Clean() s.viewMu.Lock() defer s.viewMu.Unlock() - _, ok := s.viewMap[protocol.Clean(uri)] + _, ok := s.viewMap[uri] return ok } @@ -379,6 +380,7 @@ func (s *Session) View(id string) (*View, error) { // // On success, the caller must call the returned function to release the snapshot. 
func (s *Session) SnapshotOf(ctx context.Context, uri protocol.DocumentURI) (*Snapshot, func(), error) { + uri = uri.Clean() // Fast path: if the uri has a static association with a view, return it. s.viewMu.Lock() v, err := s.viewOfLocked(ctx, uri) @@ -396,7 +398,7 @@ func (s *Session) SnapshotOf(ctx context.Context, uri protocol.DocumentURI) (*Sn // View is shut down. Forget this association. s.viewMu.Lock() if s.viewMap[uri] == v { - delete(s.viewMap, protocol.Clean(uri)) + delete(s.viewMap, uri) } s.viewMu.Unlock() } @@ -473,7 +475,7 @@ var errNoViews = errors.New("no views") // viewOfLocked evaluates the best view for uri, memoizing its result in // s.viewMap. // -// Precondition: caller holds s.viewMu lock. +// Precondition: caller holds s.viewMu lock; uri must be clean. // // May return (nil, nil) if no best view can be determined. func (s *Session) viewOfLocked(ctx context.Context, uri protocol.DocumentURI) (*View, error) { @@ -500,7 +502,7 @@ func (s *Session) viewOfLocked(ctx context.Context, uri protocol.DocumentURI) (* // (as in golang/go#60776). v = relevantViews[0] } - s.viewMap[protocol.Clean(uri)] = v // may be nil + s.viewMap[uri] = v // may be nil } return v, nil } @@ -748,7 +750,7 @@ func (s *Session) ResetView(ctx context.Context, uri protocol.DocumentURI) (*Vie return nil, fmt.Errorf("session is shut down") } - view, err := s.viewOfLocked(ctx, uri) + view, err := s.viewOfLocked(ctx, uri.Clean()) if err != nil { return nil, err } diff --git a/gopls/internal/protocol/uri.go b/gopls/internal/protocol/uri.go index 661521060f8..ead27a1313a 100644 --- a/gopls/internal/protocol/uri.go +++ b/gopls/internal/protocol/uri.go @@ -68,7 +68,7 @@ func (uri *DocumentURI) UnmarshalText(data []byte) (err error) { } // Clean returns the cleaned uri by triggering filepath.Clean underlying. 
-func Clean(uri DocumentURI) DocumentURI { +func (uri DocumentURI) Clean() DocumentURI { return URIFromPath(filepath.Clean(uri.Path())) } diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index 6ac9dc17e04..aeb41c7a71f 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -327,7 +327,11 @@ func (e *Editor) initialize(ctx context.Context) error { params.InitializationOptions = makeSettings(e.sandbox, config, nil) params.WorkspaceFolders = makeWorkspaceFolders(e.sandbox, config.WorkspaceFolders, config.NoDefaultWorkspaceFiles) - params.RootURI = protocol.DocumentURI(makeRootURI(e.sandbox, config.RelRootPath)) + params.RootURI = protocol.URIFromPath(config.RelRootPath) + if !uriRE.MatchString(config.RelRootPath) { // relative file path + params.RootURI = e.sandbox.Workdir.URI(config.RelRootPath) + } + capabilities, err := clientCapabilities(config) if err != nil { return fmt.Errorf("unmarshalling EditorConfig.CapabilitiesJSON: %v", err) @@ -447,10 +451,10 @@ var uriRE = regexp.MustCompile(`^[a-z][a-z0-9+\-.]*://\S+`) // makeWorkspaceFolders creates a slice of workspace folders to use for // this editing session, based on the editor configuration. func makeWorkspaceFolders(sandbox *Sandbox, paths []string, useEmpty bool) (folders []protocol.WorkspaceFolder) { - if len(paths) == 0 && useEmpty { - return nil - } if len(paths) == 0 { + if useEmpty { + return nil + } paths = []string{string(sandbox.Workdir.RelativeTo)} } @@ -468,14 +472,6 @@ func makeWorkspaceFolders(sandbox *Sandbox, paths []string, useEmpty bool) (fold return folders } -func makeRootURI(sandbox *Sandbox, path string) string { - uri := path - if !uriRE.MatchString(path) { // relative file path - uri = string(sandbox.Workdir.URI(path)) - } - return uri -} - // onFileChanges is registered to be called by the Workdir on any writes that // go through the Workdir API. 
It is called synchronously by the Workdir. func (e *Editor) onFileChanges(ctx context.Context, evts []protocol.FileEvent) { From 578c1213983a83e6411536ddf6bbf3a1faf97aea Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Thu, 5 Jun 2025 13:56:18 -0700 Subject: [PATCH 196/196] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. Change-Id: If816f0214d746b23dc26eb56c0dae2c97b74f1a7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/679316 Auto-Submit: Gopher Robot Reviewed-by: Dmitri Shuralyov LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- go.mod | 6 +++--- go.sum | 12 ++++++------ gopls/go.mod | 8 ++++---- gopls/go.sum | 16 ++++++++-------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index 91de2267573..634ab1b6165 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.23.0 require ( github.com/google/go-cmp v0.6.0 github.com/yuin/goldmark v1.4.13 - golang.org/x/mod v0.24.0 - golang.org/x/net v0.40.0 - golang.org/x/sync v0.14.0 + golang.org/x/mod v0.25.0 + golang.org/x/net v0.41.0 + golang.org/x/sync v0.15.0 golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 ) diff --git a/go.sum b/go.sum index 6a01512f3e4..de7de84f9db 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,12 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= -golang.org/x/sync v0.14.0 
h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 h1:zf5N6UOrA487eEFacMePxjXAJctxKmyjKUsjA11Uzuk= diff --git a/gopls/go.mod b/gopls/go.mod index 80da71f797e..d11c3b13a79 100644 --- a/gopls/go.mod +++ b/gopls/go.mod @@ -6,11 +6,10 @@ require ( github.com/fatih/gomodifytags v1.17.1-0.20250423142747-f3939df9aa3c github.com/google/go-cmp v0.6.0 github.com/jba/templatecheck v0.7.1 - golang.org/x/mod v0.24.0 - golang.org/x/sync v0.14.0 - golang.org/x/sys v0.33.0 + golang.org/x/mod v0.25.0 + golang.org/x/sync v0.15.0 golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 - golang.org/x/text v0.25.0 + golang.org/x/text v0.26.0 golang.org/x/tools v0.33.1-0.20250521210010-423c5afcceff golang.org/x/vuln v1.1.4 gopkg.in/yaml.v3 v3.0.1 @@ -25,6 +24,7 @@ require ( github.com/fatih/structtag v1.2.0 // indirect github.com/google/safehtml v0.1.0 // indirect golang.org/x/exp/typeparams v0.0.0-20250218142911-aa4b98e5adaa // indirect + golang.org/x/sys v0.33.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) diff --git a/gopls/go.sum b/gopls/go.sum index d6d9d39c7cd..24ae4b66bab 100644 --- a/gopls/go.sum +++ b/gopls/go.sum @@ -22,19 +22,19 @@ github.com/rogpeppe/go-internal 
v1.13.2-0.20241226121412-a5dc8ff20d0a h1:w3tdWGK github.com/rogpeppe/go-internal v1.13.2-0.20241226121412-a5dc8ff20d0a/go.mod h1:S8kfXMp+yh77OxPD4fdM6YUknrZpQxLhvxzS4gDHENY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp/typeparams v0.0.0-20250218142911-aa4b98e5adaa h1:Br3+0EZZohShrmVVc85znGpxw7Ca8hsUJlrdT/JQGw8= golang.org/x/exp/typeparams v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:LKZHyeOpPuZcMgxeHjJp4p5yvxrCX1xDvH10zYHhjjQ= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.5.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -50,8 +50,8 @@ golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/vuln v1.1.4 h1:Ju8QsuyhX3Hk8ma3CesTbO8vfJD9EvUBgHvkxHBzj0I= golang.org/x/vuln v1.1.4/go.mod h1:F+45wmU18ym/ca5PLTPLsSzr2KppzswxPP603ldA67s= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=