diff --git a/lua/mason-core/async/control.lua b/lua/mason-core/async/control.lua index 3252c0709..df2627f6a 100644 --- a/lua/mason-core/async/control.lua +++ b/lua/mason-core/async/control.lua @@ -5,28 +5,21 @@ local Condvar = {} Condvar.__index = Condvar function Condvar.new() - return setmetatable({ handles = {}, queue = {}, is_notifying = false }, Condvar) + return setmetatable({ handles = {} }, Condvar) end ---@async function Condvar:wait() a.wait(function(resolve) - if self.is_notifying then - self.queue[resolve] = true - else - self.handles[resolve] = true - end + self.handles[#self.handles + 1] = resolve end) end function Condvar:notify_all() - self.is_notifying = true - for handle in pairs(self.handles) do - handle() + for _, handle in ipairs(self.handles) do + pcall(handle) end - self.handles = self.queue - self.queue = {} - self.is_notifying = false + self.handles = {} end local Permit = {} @@ -69,7 +62,42 @@ function Semaphore:acquire() return Permit.new(self) end +---@class OneShotChannel +---@field has_sent boolean +---@field value any +---@field condvar Condvar +local OneShotChannel = {} +OneShotChannel.__index = OneShotChannel + +function OneShotChannel.new() + return setmetatable({ + has_sent = false, + value = nil, + condvar = Condvar.new(), + }, OneShotChannel) +end + +function OneShotChannel:is_closed() + return self.has_sent +end + +function OneShotChannel:send(...) + assert(not self.has_sent, "Oneshot channel can only send once.") + self.has_sent = true + self.value = { ... } + self.condvar:notify_all() + self.condvar = nil +end + +function OneShotChannel:receive() + if not self.has_sent then + self.condvar:wait() + end + return unpack(self.value) +end + return { Condvar = Condvar, Semaphore = Semaphore, + OneShotChannel = OneShotChannel, } diff --git a/lua/mason-core/async/init.lua b/lua/mason-core/async/init.lua index a6c7c8efd..df7c996c9 100644 --- a/lua/mason-core/async/init.lua +++ b/lua/mason-core/async/init.lua @@ -153,41 +153,11 @@ exports.scheduler = function() await(vim.schedule) end ----Creates a oneshot channel that can only send once. -local function oneshot_channel() - local has_sent = false - local sent_value - local saved_callback - - return { - is_closed = function() - return has_sent - end, - send = function(...) - assert(not has_sent, "Oneshot channel can only send once.") - has_sent = true - sent_value = { ... } - if saved_callback then - saved_callback(unpack(sent_value)) - end - end, - receive = function() - return await(function(resolve) - if has_sent then - resolve(unpack(sent_value)) - else - saved_callback = resolve - end - end) - end, - } -end - ---@async ---@param suspend_fns async fun()[] ---@param mode '"first"' | '"all"' local function wait(suspend_fns, mode) - local channel = oneshot_channel() + local channel = require("mason-core.async.control").OneShotChannel.new() if #suspend_fns == 0 then return end @@ -208,9 +178,9 @@ local function wait(suspend_fns, mode) thread_cancellations[i] = exports.run(suspend_fn, function(success, result) completed = completed + 1 if not success then - if not channel.is_closed() then + if not channel:is_closed() then cancel() - channel.send(false, result) + channel:send(false, result) results = nil thread_cancellations = {} end @@ -218,7 +188,7 @@ local function wait(suspend_fns, mode) results[i] = result if mode == "first" or completed >= count then cancel() - channel.send(true, mode == "first" and { result } or results) + channel:send(true, mode == "first" and { result } or results) results = nil thread_cancellations = {} end @@ -227,7 +197,7 @@ local function wait(suspend_fns, mode) end end - local ok, results = channel.receive() + local ok, results = channel:receive() if not ok then error(results, 2) end diff --git a/lua/mason-registry/init.lua b/lua/mason-registry/init.lua index 67a63976f..93472ef3c 100644 --- a/lua/mason-registry/init.lua +++ b/lua/mason-registry/init.lua @@ -150,25 +150,8 @@ end ---@param callback? fun(success: boolean, updated_registries: RegistrySource[]) function M.update(callback) local a = require "mason-core.async" - local Result = require "mason-core.result" - return a.run(function() - return Result.try(function(try) - local updated_sources = {} - for source in sources.iter { include_uninstalled = true } do - source:get_installer():if_present(function(installer) - try(installer():map_err(function(err) - return ("%s failed to install: %s"):format(source, err) - end)) - table.insert(updated_sources, source) - end) - end - return updated_sources - end):on_success(function(updated_sources) - if #updated_sources > 0 then - M:emit("update", updated_sources) - end - end) - end, function(success, result) + + return a.run(require("mason-registry.installer").run, function(success, result) if not callback then return end diff --git a/lua/mason-registry/installer.lua b/lua/mason-registry/installer.lua new file mode 100644 index 000000000..31fe0d85d --- /dev/null +++ b/lua/mason-registry/installer.lua @@ -0,0 +1,38 @@ +local a = require "mason-core.async" +local OneShotChannel = require("mason-core.async.control").OneShotChannel +local Result = require "mason-core.result" +local sources = require "mason-registry.sources" + +local M = {} + +---@type OneShotChannel? +local update_channel + +---@async +function M.run() + if not update_channel or update_channel:is_closed() then + update_channel = OneShotChannel.new() + a.run(function() + update_channel:send(Result.try(function(try) + local updated_sources = {} + for source in sources.iter { include_uninstalled = true } do + source:get_installer():if_present(function(installer) + try(installer():map_err(function(err) + return ("%s failed to install: %s"):format(source, err) + end)) + table.insert(updated_sources, source) + end) + end + return updated_sources + end):on_success(function(updated_sources) + if #updated_sources > 0 then + require("mason-registry"):emit("update", updated_sources) + end + end)) + end, function() end) + end + + return update_channel:receive() +end + +return M