Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(registry): use oneshot channel for updating registry #1168

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions lua/mason-core/async/control.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,21 @@ local Condvar = {}
Condvar.__index = Condvar

function Condvar.new()
return setmetatable({ handles = {}, queue = {}, is_notifying = false }, Condvar)
return setmetatable({ handles = {} }, Condvar)
end

---@async
function Condvar:wait()
a.wait(function(resolve)
if self.is_notifying then
self.queue[resolve] = true
else
self.handles[resolve] = true
end
self.handles[#self.handles + 1] = resolve
end)
end

function Condvar:notify_all()
self.is_notifying = true
for handle in pairs(self.handles) do
handle()
for _, handle in ipairs(self.handles) do
pcall(handle)
end
self.handles = self.queue
self.queue = {}
self.is_notifying = false
self.handles = {}
end

local Permit = {}
Expand Down Expand Up @@ -69,7 +62,42 @@ function Semaphore:acquire()
return Permit.new(self)
end

---@class OneShotChannel
---@field has_sent boolean
---@field value any
---@field condvar Condvar
local OneShotChannel = {}
OneShotChannel.__index = OneShotChannel

function OneShotChannel.new()
return setmetatable({
has_sent = false,
value = nil,
condvar = Condvar.new(),
}, OneShotChannel)
end

function OneShotChannel:is_closed()
return self.has_sent
end

function OneShotChannel:send(...)
assert(not self.has_sent, "Oneshot channel can only send once.")
self.has_sent = true
self.value = { ... }
self.condvar:notify_all()
self.condvar = nil
end

function OneShotChannel:receive()
if not self.has_sent then
self.condvar:wait()
end
return unpack(self.value)
end

return {
Condvar = Condvar,
Semaphore = Semaphore,
OneShotChannel = OneShotChannel,
}
40 changes: 5 additions & 35 deletions lua/mason-core/async/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -153,41 +153,11 @@ exports.scheduler = function()
await(vim.schedule)
end

---Creates a oneshot channel that can only send once.
local function oneshot_channel()
local has_sent = false
local sent_value
local saved_callback

return {
is_closed = function()
return has_sent
end,
send = function(...)
assert(not has_sent, "Oneshot channel can only send once.")
has_sent = true
sent_value = { ... }
if saved_callback then
saved_callback(unpack(sent_value))
end
end,
receive = function()
return await(function(resolve)
if has_sent then
resolve(unpack(sent_value))
else
saved_callback = resolve
end
end)
end,
}
end

---@async
---@param suspend_fns async fun()[]
---@param mode '"first"' | '"all"'
local function wait(suspend_fns, mode)
local channel = oneshot_channel()
local channel = require("mason-core.async.control").OneShotChannel.new()
if #suspend_fns == 0 then
return
end
Expand All @@ -208,17 +178,17 @@ local function wait(suspend_fns, mode)
thread_cancellations[i] = exports.run(suspend_fn, function(success, result)
completed = completed + 1
if not success then
if not channel.is_closed() then
if not channel:is_closed() then
cancel()
channel.send(false, result)
channel:send(false, result)
results = nil
thread_cancellations = {}
end
else
results[i] = result
if mode == "first" or completed >= count then
cancel()
channel.send(true, mode == "first" and { result } or results)
channel:send(true, mode == "first" and { result } or results)
results = nil
thread_cancellations = {}
end
Expand All @@ -227,7 +197,7 @@ local function wait(suspend_fns, mode)
end
end

local ok, results = channel.receive()
local ok, results = channel:receive()
if not ok then
error(results, 2)
end
Expand Down
21 changes: 2 additions & 19 deletions lua/mason-registry/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -150,25 +150,8 @@ end
---@param callback? fun(success: boolean, updated_registries: RegistrySource[])
function M.update(callback)
local a = require "mason-core.async"
local Result = require "mason-core.result"
return a.run(function()
return Result.try(function(try)
local updated_sources = {}
for source in sources.iter { include_uninstalled = true } do
source:get_installer():if_present(function(installer)
try(installer():map_err(function(err)
return ("%s failed to install: %s"):format(source, err)
end))
table.insert(updated_sources, source)
end)
end
return updated_sources
end):on_success(function(updated_sources)
if #updated_sources > 0 then
M:emit("update", updated_sources)
end
end)
end, function(success, result)

return a.run(require("mason-registry.installer").run, function(success, result)
if not callback then
return
end
Expand Down
38 changes: 38 additions & 0 deletions lua/mason-registry/installer.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
local a = require "mason-core.async"
local OneShotChannel = require("mason-core.async.control").OneShotChannel
local Result = require "mason-core.result"
local sources = require "mason-registry.sources"

local M = {}

---@type OneShotChannel?
local update_channel

---@async
function M.run()
if not update_channel or update_channel:is_closed() then
update_channel = OneShotChannel.new()
a.run(function()
update_channel:send(Result.try(function(try)
local updated_sources = {}
for source in sources.iter { include_uninstalled = true } do
source:get_installer():if_present(function(installer)
try(installer():map_err(function(err)
return ("%s failed to install: %s"):format(source, err)
end))
table.insert(updated_sources, source)
end)
end
return updated_sources
end):on_success(function(updated_sources)
if #updated_sources > 0 then
require("mason-registry"):emit("update", updated_sources)
end
end))
end, function() end)
end

return update_channel:receive()
end

return M