Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions script/files.lua
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,17 @@ function m.compileStateAsync(uri, callback)
end)
end

local function pluginOnTransformAst(uri, state)
local plugin = require 'plugin'
---TODO: maybe deepcopy astNode
local suc, result = plugin.dispatch('OnTransformAst', uri, state.ast)
if not suc then
return state
end
state.ast = result
return state
end

---@param uri uri
---@return parser.state?
function m.compileState(uri)
Expand Down Expand Up @@ -700,6 +711,12 @@ function m.compileState(uri)
return nil
end

state = pluginOnTransformAst(uri, state)
if not state then
log.error('pluginOnTransformAst failed! discard the file state')
return nil
end

m.compileStateThen(state, file)

return state
Expand Down
34 changes: 34 additions & 0 deletions script/parser/guide.lua
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,22 @@ function m.getParentType(obj, want)
error('guide.getParentType overstack')
end

--- 寻找所在父类型
---@param obj parser.object
---@return parser.object?
function m.getParentTypes(obj, wants)
for _ = 1, 10000 do
obj = obj.parent
if not obj then
return nil
end
if wants[obj.type] then
return obj
end
end
error('guide.getParentTypes overstack')
end

--- 寻找根区块
---@param obj parser.object
---@return parser.object
Expand Down Expand Up @@ -1289,4 +1305,22 @@ function m.isParam(source)
return true
end

---@param source parser.object
---@param index integer
---@return parser.object?
function m.getParam(source, index)
if source.type == 'call' then
local args = source.args
assert(args.type == 'callargs', 'call.args type is\'t callargs')
return args[index]
elseif source.type == 'callargs' then
return source[index]
elseif source.type == 'function' then
local args = source.args
assert(args.type == 'funcargs', 'function.args type is\'t callargs')
return args[index]
end
return nil
end

return m
2 changes: 1 addition & 1 deletion script/parser/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ local api = {
compile = require 'parser.compile',
lines = require 'parser.lines',
guide = require 'parser.guide',
luadoc = require 'parser.luadoc',
luadoc = require 'parser.luadoc'.luadoc,
}

return api
38 changes: 36 additions & 2 deletions script/parser/luadoc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,14 @@ local function bindDocWithSources(sources, binded)
bindGeneric(binded)
bindCommentsAndFields(binded)
bindReturnIndex(binded)

-- doc is special node
if lastDoc.special then
if bindDoc(lastDoc.special, binded) then
return
end
end

local row = guide.rowColOf(lastDoc.finish)
local suc = bindDocsBetween(sources, binded, guide.positionOf(row, 0), lastDoc.start)
if not suc then
Expand Down Expand Up @@ -1956,7 +1964,8 @@ local function bindDocs(state)
binded = nil
else
local nextDoc = state.ast.docs[i+1]
if not isNextLine(doc, nextDoc) then
if nextDoc and nextDoc.special
or not isNextLine(doc, nextDoc) then
bindDocWithSources(sources, binded)
binded = nil
end
Expand Down Expand Up @@ -1985,7 +1994,7 @@ local function findTouch(state, doc)
end
end

return function (state)
local function luadoc(state)
local ast = state.ast
local comments = state.comms
table.sort(comments, function (a, b)
Expand Down Expand Up @@ -2054,6 +2063,15 @@ return function (state)
end
end
end

if ast.state.pluginDocs then
for i, doc in ipairs(ast.state.pluginDocs) do
insertDoc(doc, doc.originalComment)
end
table.sort(ast.docs, function (a, b)
return a.start < b.start
end)
end

ast.docs.start = ast.start
ast.docs.finish = ast.finish
Expand All @@ -2064,3 +2082,19 @@ return function (state)

bindDocs(state)
end

return {
buildAndBindDoc = function (ast, src, comment)
local doc = buildLuaDoc(comment)
if doc then
local pluginDocs = ast.state.pluginDocs or {}
pluginDocs[#pluginDocs+1] = doc
doc.special = src
doc.originalComment = comment
ast.state.pluginDocs = pluginDocs
return true
end
return false
end,
luadoc = luadoc
}
66 changes: 66 additions & 0 deletions script/plugins/astHelper.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
local luadoc = require 'parser.luadoc'
local ssub = require 'core.substring'
local guide = require 'parser.guide'
local _M = {}

function _M.buildComment(t, value)
return {
type = 'comment.short',
start = 1,
finish = 1,
text = "-@" .. t .. " " .. value,
}
end

function _M.InsertDoc(ast, comm)
local comms = ast.state.comms or {}
comms[#comms+1] = comm
ast.state.comms = comms
end

--- give the local/global variable add doc.class
---@param ast parser.object
---@param source parser.object local/global variable
---@param classname string
function _M.addClassDoc(ast, source, classname)
if source.type ~= 'local' and not guide.isGlobal(source) then
return false
end
--TODO fileds
--TODO callers
local comment = _M.buildComment("class", classname)
comment.start = source.start - 1
comment.finish = comment.start

return luadoc.buildAndBindDoc(ast, source, comment)
end

---remove `ast` function node `index` arg, the variable will be the function local variable
---@param source parser.object function node
---@param index integer
---@return parser.object?
function _M.removeArg(source, index)
if source.type == 'function' then
local arg = table.remove(source.args, index)
if not arg then
return nil
end
arg.parent = arg.parent.parent
return arg
end
return nil
end

--- 把特定函数当成构造函数,`index` 参数是self
---@param classname string
---@param source parser.object function node
---@param index integer
function _M.addClassDocAtParam(ast, classname, source, index)
local arg = _M.removeArg(source, index)
if arg then
return _M.addClassDoc(ast, arg, classname)
end
return false
end

return _M
5 changes: 3 additions & 2 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ local function main()

testAll()
end)

test 'tclient'
test 'full'
test 'plugins.ffi.test'
end
test 'plugins.ast'
end

loadAllLibs()
main()
Expand Down
2 changes: 1 addition & 1 deletion test/full/example.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local files = require 'files'
local diag = require 'core.diagnostics'
local config = require 'config'
local fs = require 'bee.filesystem'
local luadoc = require "parser.luadoc"
local luadoc = require "parser".luadoc

-- 临时
---@diagnostic disable: await-in-sync
Expand Down
92 changes: 92 additions & 0 deletions test/plugins/ast/init.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
local config = require 'config'
local utility = require 'utility'
local parser = require 'parser'
local luadoc = require 'parser.luadoc'
local guide = require 'parser.guide'
local vm = require 'vm'
local helper = require 'plugins.astHelper'

---@diagnostic disable: await-in-sync
local function TestPlugin(script, plugin, checker)
config.set(TESTURI, 'Lua.workspace.preloadFileSize', 1000000000)
local state = parser.compile(script, "Lua", "Lua 5.4")
state.ast = plugin(TESTURI, state.ast) or state.ast
parser.luadoc(state)
checker(state)
end

local function isDocClass(ast)
return ast.bindDocs[1].type == 'doc.class'
end

local function TestAIsClass(state, next)
assert(isDocClass(state.ast[1]))
end

--- when call Class
local function plugin_AddClass(uri, ast)
guide.eachSourceType(ast, "call", function (source)
local node = source.node
if not guide.isGet(node) then
return
end
if not guide.isGlobal(node) then
return
end
if guide.getKeyName(node) ~= 'Class' then
return
end
local wants = {
['local'] = true,
['setglobal'] = true
}
local classnameNode = guide.getParentTypes(source, wants)
if not classnameNode then
return
end
local classname = guide.getKeyName(classnameNode)
if classname then
helper.addClassDoc(ast, classnameNode, classname)
end
end)
end

local function plugin_AddClassAtParam(uri, ast)
guide.eachSourceType(ast, "function", function (src)
helper.addClassDocAtParam(ast, "A", src, 1)
end)
end

local function TestSelfIsClass(state, next)
guide.eachSourceType(state.ast, "local", function (source)
if source[1] == 'self' then
assert(source.bindDocs)
assert(source.parent.type == 'function')
assert(#source.parent.args == 0)
end
end)
end

local function TestPlugin1(script)
TestPlugin(script, plugin_AddClass, TestAIsClass)
end

local function TestPlugin2(script)
TestPlugin(script, plugin_AddClassAtParam, TestSelfIsClass)
end

TestPlugin1 [[
local A = Class(function() end)
]]

TestPlugin1 [[
A = Class(function() end)
]]

TestPlugin2 [[
local function ctor(self) end
]]

TestPlugin2 [[
function ctor(self) end
]]