Skip to content

Commit

Permalink
chore(async): add Channel (#1456)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamboman authored Aug 17, 2023
1 parent 68e6a15 commit b5bb138
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 2 deletions.
62 changes: 60 additions & 2 deletions lua/mason-core/async/control.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@ function Condvar:wait()
end)
end

function Condvar:notify()
local handle = table.remove(self.handles)
pcall(handle)
end

function Condvar:notify_all()
for _, handle in ipairs(self.handles) do
pcall(handle)
while #self.handles > 0 do
self:notify()
end
self.handles = {}
end
Expand Down Expand Up @@ -97,8 +102,61 @@ function OneShotChannel:receive()
return unpack(self.value)
end

---@class Channel
---@field private condvar Condvar
---@field private buffer any?
---@field is_closed boolean
local Channel = {}
Channel.__index = Channel
function Channel.new()
return setmetatable({
condvar = Condvar.new(),
buffer = nil,
is_closed = false,
}, Channel)
end

function Channel:close()
self.is_closed = true
end

---@async
function Channel:send(value)
assert(not self.is_closed, "Channel is closed.")
while self.buffer ~= nil do
self.condvar:wait()
end
self.buffer = value
self.condvar:notify()
while self.buffer ~= nil do
self.condvar:wait()
end
end

---@async
function Channel:receive()
assert(not self.is_closed, "Channel is closed.")
while self.buffer == nil do
self.condvar:wait()
end
local value = self.buffer
self.buffer = nil
self.condvar:notify()
return value
end

---@async
function Channel:iter()
return function()
while not self.is_closed do
return self:receive()
end
end
end

return {
Condvar = Condvar,
Semaphore = Semaphore,
OneShotChannel = OneShotChannel,
Channel = Channel,
}
55 changes: 55 additions & 0 deletions tests/mason-core/async/async_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,58 @@ describe("async :: OneShotChannel", function()
assert.equals(42, channel:receive())
end)
end)

describe("async :: Channel", function()
local Channel = control.Channel

it("should suspend send until buffer is received", function()
local channel = Channel.new()
spy.on(channel, "send")
local guard = spy.new()

a.run(function()
channel:send "message"
guard()
channel:send "another message"
end, function() end)

assert.spy(channel.send).was_called(1)
assert.spy(channel.send).was_called_with(match.is_ref(channel), "message")
assert.spy(guard).was_not_called()
end)

it("should send subsequent messages after they're received", function()
local channel = Channel.new()
spy.on(channel, "send")

a.run(function()
channel:send "message"
channel:send "another message"
end, function() end)

local value = channel:receive()
assert.equals(value, "message")

assert.spy(channel.send).was_called(2)
assert.spy(channel.send).was_called_with(match.is_ref(channel), "message")
assert.spy(channel.send).was_called_with(match.is_ref(channel), "another message")
end)

it("should suspend receive until message is sent", function()
local channel = Channel.new()

a.run(function()
a.sleep(100)
channel:send "hello world"
end, function() end)

local start = timestamp()
local value = a.run_blocking(function()
return channel:receive()
end)
local stop = timestamp()

assert.is_true((stop - start) > 80)
assert.equals(value, "hello world")
end)
end)

0 comments on commit b5bb138

Please sign in to comment.