aboutsummaryrefslogtreecommitdiff
path: root/internal/lsp/code_action.go
blob: c7f99cea4099499b26e706f0ecf751d574d48559 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package lsp

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"golang.org/x/tools/internal/imports"
	"golang.org/x/tools/internal/lsp/mod"
	"golang.org/x/tools/internal/lsp/protocol"
	"golang.org/x/tools/internal/lsp/source"
	"golang.org/x/tools/internal/lsp/telemetry"
	"golang.org/x/tools/internal/span"
	"golang.org/x/tools/internal/telemetry/log"
	errors "golang.org/x/xerrors"
)

// codeAction computes the code actions available for the document and
// diagnostics described by params.
//
// The file's kind selects the set of supported code action kinds from the view
// options; that set is then narrowed to the kinds listed in
// params.Context.Only, if any. For go.mod files the result may contain a
// "Tidy" command and the fixes suggested by the mod package. For Go files it
// may contain analysis quick fixes, one quick fix per import change that
// matches a diagnostic, and an "Organize Imports" action. Any other file kind
// yields no actions and no error.
//
// An error is returned if the view or file cannot be resolved, if the file
// kind has no supported code actions, or if computing the import fixes fails.
// A failure while computing analysis quick fixes is only logged.
func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionParams) ([]protocol.CodeAction, error) {
	uri := span.NewURI(params.TextDocument.URI)
	view, err := s.session.ViewOf(uri)
	if err != nil {
		return nil, err
	}
	snapshot := view.Snapshot()
	fh, err := snapshot.GetFile(uri)
	if err != nil {
		return nil, err
	}

	// Determine the supported actions for this file kind.
	supportedCodeActions, ok := view.Options().SupportedCodeActions[fh.Identity().Kind]
	if !ok {
		return nil, fmt.Errorf("no supported code actions for %v file kind", fh.Identity().Kind)
	}

	// The Only field of the context specifies which code actions the client wants.
	// If Only is empty, assume that the client wants all of the possible code actions.
	var wanted map[protocol.CodeActionKind]bool
	if len(params.Context.Only) == 0 {
		wanted = supportedCodeActions
	} else {
		wanted = make(map[protocol.CodeActionKind]bool)
		for _, only := range params.Context.Only {
			// A requested kind that is not supported is recorded as false.
			wanted[only] = supportedCodeActions[only]
		}
	}
	// NOTE(review): when Only is non-empty, wanted holds one entry per
	// requested kind even if every value is false, so this check can only
	// fire when Only is empty and the supported set is empty too. Confirm
	// whether unsupported-only requests were meant to be rejected here.
	if len(wanted) == 0 {
		return nil, errors.Errorf("no supported code action to execute for %s, wanted %v", uri, params.Context.Only)
	}

	var codeActions []protocol.CodeAction
	switch fh.Identity().Kind {
	case source.Mod:
		// "Tidy" is published under the organize-imports kind, so it must only
		// be offered when that kind is wanted (the condition was previously
		// inverted, offering it exactly when the kind was NOT wanted).
		if wanted[protocol.SourceOrganizeImports] {
			codeActions = append(codeActions, protocol.CodeAction{
				Title: "Tidy",
				Kind:  protocol.SourceOrganizeImports,
				Command: &protocol.Command{
					Title:     "Tidy",
					Command:   "tidy",
					Arguments: []interface{}{fh.Identity().URI},
				},
			})
		}
		if diagnostics := params.Context.Diagnostics; len(diagnostics) > 0 {
			codeActions = append(codeActions, mod.SuggestedFixes(ctx, snapshot, fh, diagnostics)...)
		}
	case source.Go:
		// edits is the combined import change used by "Organize Imports";
		// editsPerFix carries the edits of each individual import fix.
		edits, editsPerFix, err := source.AllImportsFixes(ctx, snapshot, fh)
		if err != nil {
			return nil, err
		}
		if diagnostics := params.Context.Diagnostics; wanted[protocol.QuickFix] && len(diagnostics) > 0 {
			// First, add the quick fixes reported by go/analysis.
			// A failure here is logged rather than returned so that the
			// import-related actions below are still offered.
			qf, err := quickFixes(ctx, snapshot, fh, diagnostics)
			if err != nil {
				log.Error(ctx, "quick fixes failed", err, telemetry.File.Of(uri))
			}
			codeActions = append(codeActions, qf...)

			// If we also have diagnostics for missing imports, we can associate them with quick fixes.
			if findImportErrors(diagnostics) {
				// Separate this into a set of codeActions per diagnostic, where
				// each action is the addition, removal, or renaming of one import.
				for _, importFix := range editsPerFix {
					// Get the diagnostics this fix would affect.
					if fixDiagnostics := importDiagnostics(importFix.Fix, diagnostics); len(fixDiagnostics) > 0 {
						codeActions = append(codeActions, protocol.CodeAction{
							Title: importFixTitle(importFix.Fix),
							Kind:  protocol.QuickFix,
							Edit: protocol.WorkspaceEdit{
								DocumentChanges: documentChanges(fh, importFix.Edits),
							},
							Diagnostics: fixDiagnostics,
						})
					}
				}
			}
		}
		if wanted[protocol.SourceOrganizeImports] && len(edits) > 0 {
			codeActions = append(codeActions, protocol.CodeAction{
				Title: "Organize Imports",
				Kind:  protocol.SourceOrganizeImports,
				Edit: protocol.WorkspaceEdit{
					DocumentChanges: documentChanges(fh, edits),
				},
			})
		}
	default:
		// Unsupported file kind for a code action.
		return nil, nil
	}
	return codeActions, nil
}

// getSupportedCodeActions returns, in ascending order and without duplicates,
// every code action kind that appears as a key in the session's
// SupportedCodeActions option, across all file kinds. The value stored under
// each key is not consulted.
func (s *Server) getSupportedCodeActions() []protocol.CodeActionKind {
	var (
		seen   = make(map[protocol.CodeActionKind]struct{})
		sorted []protocol.CodeActionKind
	)
	for _, byKind := range s.session.Options().SupportedCodeActions {
		for kind := range byKind {
			if _, dup := seen[kind]; dup {
				continue
			}
			seen[kind] = struct{}{}
			sorted = append(sorted, kind)
		}
	}
	// Map iteration order is unspecified, so sort for a stable result.
	sort.Slice(sorted, func(a, b int) bool {
		return sorted[a] < sorted[b]
	})
	return sorted
}

// protocolImportFix pairs an import fix with protocol-level text edits.
// NOTE(review): nothing in this file references this type; presumably edits
// holds the edits that apply fix — confirm against other users in the package.
type protocolImportFix struct {
	// fix is the import fix from the imports package.
	fix   *imports.ImportFix
	// edits are the associated text edits in LSP protocol form.
	edits []protocol.TextEdit
}

// findImportErrors reports whether at least one of the given diagnostics
// looks like an error that could be fixed by organizing imports.
// TODO(rstambler): We need a better way to check this than string matching.
func findImportErrors(diagnostics []protocol.Diagnostic) bool {
	for i := range diagnostics {
		msg := diagnostics[i].Message
		switch {
		case strings.HasPrefix(msg, "undeclared name: "):
			// "undeclared name: X" may be an unresolved import.
			return true
		case strings.HasPrefix(msg, "could not import: "):
			// "could not import: X" may be an invalid import.
			return true
		case strings.Contains(msg, " imported but not used"):
			// "X imported but not used" and "X imported but not used as Y"
			// both denote an unused import.
			return true
		}
	}
	return false
}

// importFixTitle builds the title shown for a code action that applies a
// single import fix. The wording depends on the fix type (add, delete, or
// rename) and includes the import's name and quoted path. Any other fix type
// yields the empty string.
func importFixTitle(fix *imports.ImportFix) string {
	var format string
	switch fix.FixType {
	case imports.AddImport:
		format = "Add import: %s %q"
	case imports.DeleteImport:
		format = "Delete import: %s %q"
	case imports.SetImportName:
		format = "Rename import: %s %q"
	default:
		return ""
	}
	return fmt.Sprintf(format, fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
}

// importDiagnostics returns the subset of diagnostics, in their original
// order, that the given import fix would address. Matching is done on the
// diagnostic message text:
//
//   - "undeclared name: X" and "could not import: X" match when X equals
//     fix.IdentName;
//   - "X imported but not used[ as Y]" matches when X equals the quoted
//     import path of the fix.
//
// A message is classified by the first pattern it fits; if it then fails that
// pattern's comparison it is not tested against the later patterns.
func importDiagnostics(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) (results []protocol.Diagnostic) {
	const (
		undeclaredPrefix = "undeclared name: "
		importPrefix     = "could not import: "
		unusedMarker     = " imported but not used"
	)
	for _, d := range diagnostics {
		msg := d.Message
		affected := false
		if strings.HasPrefix(msg, undeclaredPrefix) {
			// "undeclared name: X" may be an unresolved import.
			affected = strings.TrimPrefix(msg, undeclaredPrefix) == fix.IdentName
		} else if strings.HasPrefix(msg, importPrefix) {
			// "could not import: X" may be an invalid import.
			affected = strings.TrimPrefix(msg, importPrefix) == fix.IdentName
		} else if at := strings.Index(msg, unusedMarker); at >= 0 {
			// Everything before the marker is the import path as it appears
			// in the message, which is compared against the quoted path.
			affected = msg[:at] == fmt.Sprintf("%q", fix.StmtInfo.ImportPath)
		}
		if affected {
			results = append(results, d)
		}
	}
	return results
}

// quickFixes returns one quick-fix code action for every suggested fix
// attached to an analysis error that corresponds to one of the given
// diagnostics. Diagnostics for which no analysis error is found are skipped.
// An error is returned only if the package handles for fh cannot be obtained
// or the widest one cannot be selected.
func quickFixes(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle, diagnostics []protocol.Diagnostic) ([]protocol.CodeAction, error) {
	handles, err := snapshot.PackageHandles(ctx, fh)
	if err != nil {
		return nil, err
	}
	// We get the package that source.Diagnostics would've used. This is a hack.
	// TODO(golang/go#32443): The correct solution will be to cache diagnostics per-file per-snapshot.
	widest, err := source.WidestPackageHandle(handles)
	if err != nil {
		return nil, err
	}

	var actions []protocol.CodeAction
	for _, d := range diagnostics {
		// This code assumes that the analyzer name is the Source of the diagnostic.
		// If this ever changes, this will need to be addressed.
		analysisErr, err := snapshot.FindAnalysisError(ctx, widest.ID(), d.Source, d.Message, d.Range)
		if err != nil {
			continue
		}
		for _, suggested := range analysisErr.SuggestedFixes {
			// Gather the edits of every file touched by this fix; files that
			// cannot be resolved are logged and left out.
			var changes []protocol.TextDocumentEdit
			for editURI, fileEdits := range suggested.Edits {
				target, err := snapshot.GetFile(editURI)
				if err != nil {
					log.Error(ctx, "no file", err, telemetry.URI.Of(editURI))
					continue
				}
				changes = append(changes, documentChanges(target, fileEdits)...)
			}
			actions = append(actions, protocol.CodeAction{
				Title:       suggested.Title,
				Kind:        protocol.QuickFix,
				Diagnostics: []protocol.Diagnostic{d},
				Edit:        protocol.WorkspaceEdit{DocumentChanges: changes},
			})
		}
	}
	return actions, nil
}

// documentChanges wraps edits in a single-element slice of text document
// edits addressed to the file behind fh, identified by the URI and version
// taken from the handle's identity.
func documentChanges(fh source.FileHandle, edits []protocol.TextEdit) []protocol.TextDocumentEdit {
	identity := fh.Identity()

	var doc protocol.VersionedTextDocumentIdentifier
	doc.Version = identity.Version
	doc.TextDocumentIdentifier = protocol.TextDocumentIdentifier{
		URI: protocol.NewURI(identity.URI),
	}

	change := protocol.TextDocumentEdit{
		TextDocument: doc,
		Edits:        edits,
	}
	return []protocol.TextDocumentEdit{change}
}