Skip to content

Commit 9e1e7cc

Browse files
authored
Merge pull request #2177 from sewbacca/feature/shortcut-autorequire
[Feature] Add action to autorequire undefined globals
2 parents 00fdd47 + 76315ef commit 9e1e7cc

File tree

7 files changed

+238
-48
lines changed

7 files changed

+238
-48
lines changed

locale/en-us/script.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,8 @@ ACTION_ADD_DICT =
448448
'Add \'{}\' to workspace dict'
449449
ACTION_FIX_ADD_PAREN = -- TODO: need translate!
450450
'Add parentheses.'
451+
ACTION_AUTOREQUIRE = -- TODO: need translate!
452+
"Import '{}' as {}"
451453

452454
COMMAND_DISABLE_DIAG =
453455
'Disable diagnostics'

script/core/code-action.lua

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ local util = require 'utility'
44
local sp = require 'bee.subprocess'
55
local guide = require "parser.guide"
66
local converter = require 'proto.converter'
7+
local autoreq = require 'core.completion.auto-require'
8+
local rpath = require 'workspace.require-path'
9+
local furi = require 'file-uri'
10+
local undefined = require 'core.diagnostics.undefined-global'
11+
local vm = require 'vm'
712

813
---@param uri uri
914
---@param row integer
@@ -676,6 +681,54 @@ local function checkJsonToLua(results, uri, start, finish)
676681
}
677682
end
678683

684+
local function findRequireTargets(visiblePaths)
685+
local targets = {}
686+
for _, visible in ipairs(visiblePaths) do
687+
targets[#targets+1] = visible.name
688+
end
689+
return targets
690+
end
691+
692+
local function checkMissingRequire(results, uri, start, finish)
693+
local state = files.getState(uri)
694+
local text = files.getText(uri)
695+
if not state or not text then
696+
return
697+
end
698+
699+
local function addRequires(global, endpos)
700+
autoreq.check(state, global, endpos, function(moduleFile, stemname, targetSource)
701+
local visiblePaths = rpath.getVisiblePath(uri, furi.decode(moduleFile))
702+
if not visiblePaths or #visiblePaths == 0 then return end
703+
704+
for _, target in ipairs(findRequireTargets(visiblePaths)) do
705+
results[#results+1] = {
706+
title = lang.script('ACTION_AUTOREQUIRE', target, global),
707+
kind = 'refactor.rewrite',
708+
command = {
709+
title = 'autoRequire',
710+
command = 'lua.autoRequire',
711+
arguments = {
712+
{
713+
uri = guide.getUri(state.ast),
714+
target = moduleFile,
715+
name = global,
716+
requireName = target
717+
},
718+
},
719+
}
720+
}
721+
end
722+
end)
723+
end
724+
725+
guide.eachSourceBetween(state.ast, start, finish, function (source)
726+
if vm.isUndefinedGlobal(source) then
727+
addRequires(source[1], source.finish)
728+
end
729+
end)
730+
end
731+
679732
return function (uri, start, finish, diagnostics)
680733
local ast = files.getState(uri)
681734
if not ast then
@@ -688,6 +741,7 @@ return function (uri, start, finish, diagnostics)
688741
checkSwapParams(results, uri, start, finish)
689742
--checkExtractAsFunction(results, uri, start, finish)
690743
checkJsonToLua(results, uri, start, finish)
744+
checkMissingRequire(results, uri, start, finish)
691745

692746
return results
693747
end

script/core/command/autoRequire.lua

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ return function (data)
135135
local uri = data.uri
136136
local target = data.target
137137
local name = data.name
138+
local requireName = data.requireName
138139
local state = files.getState(uri)
139140
if not state then
140141
return
@@ -149,11 +150,13 @@ return function (data)
149150
return #a.name < #b.name
150151
end)
151152

152-
local result = askAutoRequire(uri, visiblePaths)
153-
if not result then
154-
return
153+
if not requireName then
154+
requireName = askAutoRequire(uri, visiblePaths)
155+
if not requireName then
156+
return
157+
end
155158
end
156159

157160
local offset, fmt = findInsertRow(uri)
158-
applyAutoRequire(uri, offset, name, result, fmt)
161+
applyAutoRequire(uri, offset, name, requireName, fmt)
159162
end

script/core/diagnostics/undefined-global.lua

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,21 @@ return function (uri, callback)
2020
return
2121
end
2222

23-
local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
24-
local rspecial = config.get(uri, 'Lua.runtime.special')
25-
local cache = {}
26-
2723
-- 遍历全局变量,检查所有没有 set 模式的全局变量
2824
guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async
29-
local key = src[1]
30-
if not key then
31-
return
32-
end
33-
if dglobals[key] then
34-
return
35-
end
36-
if rspecial[key] then
37-
return
38-
end
39-
local node = src.node
40-
if node.tag ~= '_ENV' then
41-
return
42-
end
43-
if cache[key] == nil then
44-
await.delay()
45-
cache[key] = vm.hasGlobalSets(uri, 'variable', key)
46-
end
47-
if cache[key] then
48-
return
49-
end
50-
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
51-
if requireLike[key:lower()] then
52-
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
25+
if vm.isUndefinedGlobal(src) then
26+
local key = src[1]
27+
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
28+
if requireLike[key:lower()] then
29+
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
30+
end
31+
32+
callback {
33+
start = src.start,
34+
finish = src.finish,
35+
message = message,
36+
undefinedGlobal = src[1]
37+
}
5338
end
54-
callback {
55-
start = src.start,
56-
finish = src.finish,
57-
message = message,
58-
}
5939
end)
6040
end

script/vm/global.lua

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local util = require 'utility'
22
local scope = require 'workspace.scope'
33
local guide = require 'parser.guide'
4+
local config = require 'config'
45
---@class vm
56
local vm = require 'vm.vm'
67

@@ -518,6 +519,33 @@ function vm.hasGlobalSets(suri, cate, name)
518519
return true
519520
end
520521

522+
---@param src parser.object
523+
local function checkIsUndefinedGlobal(src)
524+
local key = src[1]
525+
526+
local uri = guide.getUri(src)
527+
local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
528+
local rspecial = config.get(uri, 'Lua.runtime.special')
529+
530+
local node = src.node
531+
return src.type == 'getglobal' and key and not (
532+
dglobals[key] or
533+
rspecial[key] or
534+
node.tag ~= '_ENV' or
535+
vm.hasGlobalSets(uri, 'variable', key)
536+
)
537+
end
538+
539+
---@param src parser.object
540+
---@return boolean
541+
function vm.isUndefinedGlobal(src)
542+
local node = vm.compileNode(src)
543+
if node.undefinedGlobal == nil then
544+
node.undefinedGlobal = checkIsUndefinedGlobal(src)
545+
end
546+
return node.undefinedGlobal
547+
end
548+
521549
---@param source parser.object
522550
function compileObject(source)
523551
if source._globalNode ~= nil then

script/vm/node.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ vm.nodeCache = setmetatable({}, util.MODE_K)
1515
---@field [integer] vm.node.object
1616
---@field [vm.node.object] true
1717
---@field fields? table<vm.node|string, vm.node>
18+
---@field undefinedGlobal boolean?
1819
local mt = {}
1920
mt.__index = mt
2021
mt.id = 0

test/code_action/init.lua

Lines changed: 133 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,40 @@ local core = require 'core.code-action'
22
local files = require 'files'
33
local lang = require 'language'
44
local catch = require 'catch'
5+
local furi = require 'file-uri'
56

67
rawset(_G, 'TEST', true)
78

89
local EXISTS = {}
910

10-
local function eq(a, b)
11-
if a == EXISTS and b ~= nil then
11+
local function eq(expected, result)
12+
if expected == EXISTS and result ~= nil then
1213
return true
1314
end
14-
if b == EXISTS and a ~= nil then
15+
if result == EXISTS and expected ~= nil then
1516
return true
1617
end
17-
local tp1, tp2 = type(a), type(b)
18+
local tp1, tp2 = type(expected), type(result)
1819
if tp1 ~= tp2 then
19-
return false
20+
return false, string.format(": expected type %s, got %s for %s", tp1, tp2)
2021
end
2122
if tp1 == 'table' then
2223
local mark = {}
23-
for k in pairs(a) do
24-
if not eq(a[k], b[k]) then
25-
return false
24+
for k in pairs(expected) do
25+
local ok, err = eq(expected[k], result[k])
26+
if not ok then
27+
return false, string.format(".%s%s", k, err)
2628
end
2729
mark[k] = true
2830
end
29-
for k in pairs(b) do
31+
for k in pairs(result) do
3032
if not mark[k] then
31-
return false
33+
return false, string.format(".%s: missing key in result", k)
3234
end
3335
end
3436
return true
3537
end
36-
return a == b
38+
return expected == result, string.format(": expected %s, got %s", expected, result)
3739
end
3840

3941
function TEST(script)
@@ -47,6 +49,32 @@ function TEST(script)
4749
end
4850
end
4951

52+
local function TEST_CROSSFILE(testfiles)
53+
local mainscript = table.remove(testfiles, 1)
54+
return function(expected)
55+
for _, data in ipairs(testfiles) do
56+
local uri = furi.encode(data.path)
57+
files.setText(uri, data.content)
58+
files.compileState(uri)
59+
end
60+
61+
local newScript, catched = catch(mainscript, '?')
62+
files.setText(TESTURI, newScript)
63+
files.compileState(TESTURI)
64+
65+
local _ <close> = function ()
66+
for _, info in ipairs(testfiles) do
67+
files.remove(furi.encode(info.path))
68+
end
69+
files.remove(TESTURI)
70+
end
71+
72+
local results = core(TESTURI, catched['?'][1][1], catched['?'][1][2])
73+
assert(results)
74+
assert(eq(expected, results))
75+
end
76+
end
77+
5078
TEST [[
5179
print(<?a?>, b, c)
5280
]]
@@ -154,3 +182,97 @@ local t = {
154182
-- edit = EXISTS,
155183
-- },
156184
--}
185+
186+
TEST_CROSSFILE {
187+
[[
188+
<?unrequiredModule?>.myFunction()
189+
]],
190+
{
191+
path = 'unrequiredModule.lua',
192+
content = [[
193+
local m = {}
194+
m.myFunction = print
195+
return m
196+
]]
197+
}
198+
} {
199+
{
200+
title = lang.script('ACTION_AUTOREQUIRE', 'unrequiredModule', 'unrequiredModule'),
201+
kind = 'refactor.rewrite',
202+
command = {
203+
title = 'autoRequire',
204+
command = 'lua.autoRequire',
205+
arguments = {
206+
{
207+
uri = TESTURI,
208+
target = furi.encode 'unrequiredModule.lua',
209+
name = 'unrequiredModule',
210+
requireName = 'unrequiredModule'
211+
},
212+
},
213+
}
214+
}
215+
}
216+
217+
TEST_CROSSFILE {
218+
[[
219+
<?myModule?>.myFunction()
220+
]],
221+
{
222+
path = 'myModule/init.lua',
223+
content = [[
224+
local m = {}
225+
m.myFunction = print
226+
return m
227+
]]
228+
}
229+
} {
230+
{
231+
title = lang.script('ACTION_AUTOREQUIRE', 'myModule.init', 'myModule'),
232+
kind = 'refactor.rewrite',
233+
command = {
234+
title = 'autoRequire',
235+
command = 'lua.autoRequire',
236+
arguments = {
237+
{
238+
uri = TESTURI,
239+
target = furi.encode 'myModule/init.lua',
240+
name = 'myModule',
241+
requireName = 'myModule.init'
242+
},
243+
},
244+
}
245+
},
246+
{
247+
title = lang.script('ACTION_AUTOREQUIRE', 'init', 'myModule'),
248+
kind = 'refactor.rewrite',
249+
command = {
250+
title = 'autoRequire',
251+
command = 'lua.autoRequire',
252+
arguments = {
253+
{
254+
uri = TESTURI,
255+
target = furi.encode 'myModule/init.lua',
256+
name = 'myModule',
257+
requireName = 'init'
258+
},
259+
},
260+
}
261+
},
262+
{
263+
title = lang.script('ACTION_AUTOREQUIRE', 'myModule', 'myModule'),
264+
kind = 'refactor.rewrite',
265+
command = {
266+
title = 'autoRequire',
267+
command = 'lua.autoRequire',
268+
arguments = {
269+
{
270+
uri = TESTURI,
271+
target = furi.encode 'myModule/init.lua',
272+
name = 'myModule',
273+
requireName = 'myModule'
274+
},
275+
},
276+
}
277+
},
278+
}

0 commit comments

Comments
 (0)