Skip to content

Commit

Permalink
Merge pull request #1 from tkf/stack
Browse files Browse the repository at this point in the history
Use stack
  • Loading branch information
vchuravy authored Jan 9, 2022
2 parents ac51bc5 + e1d1650 commit f4fde9d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 88 deletions.
140 changes: 77 additions & 63 deletions src/ForeignCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,97 +8,111 @@ struct Node{T}
end
Node(data::T) where T = Node{T}(C_NULL, data)

setnext!(node::Ptr{Node{T}}, next::Ptr{Node{T}}) where {T} =
unsafe_store!(convert(Ptr{Ptr{Node{T}}}, node), next)

function calloc(::Type{T}) where T
ptr = Libc.malloc(sizeof(T))
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), ptr, 0, sizeof(T))
return convert(Ptr{T}, ptr)
end

mutable struct LockfreeQueue{T}
@atomic head::Ptr{Node{T}}
@atomic tail::Ptr{Node{T}}
function LockfreeQueue{T}() where T
@assert !Base.ismutabletype(T) && Base.isconcretetype(T) && Base.datatype_pointerfree(T)
tmp = calloc(Node{T})
new{T}(tmp, tmp)
"""
SingleConsumerStack{T}
A variant of Treiber stack that assumes single-consumer for simple (hopefully)
correct implementation.
Safety notes: `push!` and `unsafe_push!` on `SingleConsumerStack{T}` can be
called from multiple tasks/threads. `unsafe_push! can be called from foreign
threads without the Julia runtime. Only one Julia task is allowed to call
`popall!` at any point. This simplifies the implementation and avoids the use
of "counted pointer" (and hence 128 bit CAS) that would typically be required
for general Treiber stack with manual memory management.
"""
mutable struct SingleConsumerStack{T}
@atomic top::Ptr{Node{T}}
function SingleConsumerStack{T}() where T
@assert Base.datatype_pointerfree(Node{T})
new{T}(C_NULL)
end
end

# Notes:
# We require at-least one node in the queue.
# When we dequeue, we discard the head node and return the data
# from the new head.

function enqueue!(q::LockfreeQueue{T}, data::T) where T
function Base.push!(stack::SingleConsumerStack{T}, data::T) where T
node = Node(data)
p_node = convert(Ptr{Node{T}}, Libc.malloc(sizeof(Node{T})))
Base.unsafe_store!(p_node, node)

# Update the tail node in queue
p_tail = @atomicswap :acquire_release q.tail = p_node

# Link former tail to new tail
Core.Intrinsics.atomic_pointerset(convert(Ptr{Ptr{Node{T}}}, p_tail), p_node, :release)
return nothing
top = @atomic :monotonic stack.top
while true
setnext!(p_node, top)
top, ok = @atomicreplace :acquire_release :monotonic stack.top top => p_node
ok && return nothing
end
end

# Manual implementation of `enqueue!` since `unsafe_pointer_to_objref(ptr)::T` creates a gcframe
function unsafe_enqueue!(ptr::Ptr{Cvoid}, data::T) where T
# Manual implementation of `push!` since `unsafe_pointer_to_objref(ptr)::T` creates a gcframe
function unsafe_push!(ptr::Ptr{Cvoid}, data::T) where T
node = Node(data)
p_node = convert(Ptr{Node{T}}, Libc.malloc(sizeof(Node{T})))
Base.unsafe_store!(p_node, node)

# Update the tail node in queue
ptr += fieldoffset(ForeignCallbacks.LockfreeQueue{T}, 2)
p_tail = Core.Intrinsics.atomic_pointerswap(convert(Ptr{Ptr{Node{T}}}, ptr), p_node, :acquire_release)

# Link former tail to new tail
Core.Intrinsics.atomic_pointerset(convert(Ptr{Ptr{Node{T}}}, p_tail), p_node, :release)
return nothing
p_top = Ptr{Ptr{Node{T}}}(ptr)
top = Core.Intrinsics.atomic_pointerref(p_top, :monotonic)
while true
setnext!(p_node, top)
top, ok = Core.Intrinsics.atomic_pointerreplace(
p_top,
top,
p_node,
:acquire_release,
:monotonic,
)
ok && return nothing
end
end

function dequeue!(q::LockfreeQueue{T}) where T
p_head = @atomic :acquire q.head
popall!(stack::SingleConsumerStack{T}) where T = moveto!(T[], stack)

success = false
p_new_head = convert(Ptr{Node{T}}, C_NULL)
while !success
# Load new head
p_new_head = Core.Intrinsics.atomic_pointerref(convert(Ptr{Ptr{Node{T}}}, p_head), :acquire)
if p_new_head == convert(Ptr{Node{T}}, C_NULL)
return nothing # never remove the last node, queue is empty
end
# Attempt replacement of current head with new head
p_head, success = @atomicreplace :acquire_release :monotonic q.head p_head => p_new_head
function moveto!(results::AbstractVector{T}, stack::SingleConsumerStack{T}) where T
p_node = @atomic :monotonic stack.top
while true
p_node, ok = @atomicreplace :acquire_release :monotonic stack.top p_node => C_NULL
ok && break
end

# Copy the node data into `results` vector
while p_node != C_NULL
node = unsafe_load(p_node)
Libc.free(p_node)
push!(results, node.data)
p_node = node.next
end

# We have atomically advanced head and claimed a node.
# We return the data from the new head
# The lists starts of with a temporary node, which we will now free.
head = unsafe_load(p_new_head) # p_head is now valid to free
# TODO: Is there a potential race between `free(p_head)` and `unsafe_load(p_head)`
# on the previous `dequeue!`?
# As long as we only have one consumer this is fine.
Libc.free(p_head)
return Some(head.data)

return results
end

mutable struct ForeignCallback{T}
queue::LockfreeQueue{T}
queue::SingleConsumerStack{T}
cond::Base.AsyncCondition

function ForeignCallback{T}(callback) where T
queue = LockfreeQueue{T}()

cond = Base.AsyncCondition() do _
data = dequeue!(queue)
while data !== nothing
Base.errormonitor(Threads.@spawn callback(something($data)))
data = dequeue!(queue)
task::Task

function ForeignCallback{T}(callback; fifo::Bool = true) where T
stack = SingleConsumerStack{T}()
cond = Base.AsyncCondition()
mayreverse = fifo ? Iterators.reverse : identity
task = Threads.@spawn begin
local results = T[]
while isopen(cond)
wait(cond)
moveto!(results, stack)
for data in mayreverse(results)
Base.errormonitor(Threads.@spawn callback(data))
end
empty!(results)
end
return
end
this = new{T}(queue, cond)
this = new{T}(stack, cond, task)
finalizer(this) do this
close(this.cond)
# TODO: free queue we are leaking at least one node here
Expand All @@ -121,7 +135,7 @@ by calling `notify!`.
ForeignToken(fc::ForeignCallback) = ForeignToken(fc.cond.handle, Base.pointer_from_objref(fc.queue))

function notify!(token::ForeignToken, data::T) where T
unsafe_enqueue!(token.queue, data)
unsafe_push!(token.queue, data)
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), token.handle)
return
end
Expand Down
49 changes: 24 additions & 25 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,28 @@ struct Message2
end

@testset "Constructor" begin
@test_throws AssertionError ForeignCallbacks.LockfreeQueue{Ref{Int}}()
@test_throws AssertionError ForeignCallbacks.LockfreeQueue{Base.RefValue{Int}}()
@test_throws AssertionError ForeignCallbacks.LockfreeQueue{Message2}()
@test_throws AssertionError ForeignCallbacks.LockfreeQueue{Array}()
@test_throws AssertionError ForeignCallbacks.LockfreeQueue{Array{Int, 1}}()
@test_throws AssertionError ForeignCallbacks.SingleConsumerStack{Ref{Int}}()
@test_throws AssertionError ForeignCallbacks.SingleConsumerStack{Base.RefValue{Int}}()
@test_throws AssertionError ForeignCallbacks.SingleConsumerStack{Message2}()
@test_throws AssertionError ForeignCallbacks.SingleConsumerStack{Array}()
@test_throws AssertionError ForeignCallbacks.SingleConsumerStack{Array{Int, 1}}()
end

@testset "Queue" begin
lfq = ForeignCallbacks.LockfreeQueue{Int}()
@testset "SingleConsumerStack" begin
lfq = ForeignCallbacks.SingleConsumerStack{Int}()

@test ForeignCallbacks.dequeue!(lfq) === nothing
ForeignCallbacks.enqueue!(lfq, 1)
@test ForeignCallbacks.dequeue!(lfq) === Some(1)
@test ForeignCallbacks.dequeue!(lfq) === nothing
@test ForeignCallbacks.popall!(lfq) == []
push!(lfq, 1)
@test ForeignCallbacks.popall!(lfq) == [1]
@test ForeignCallbacks.popall!(lfq) == []

GC.@preserve lfq begin
ptr = Base.pointer_from_objref(lfq)
ForeignCallbacks.unsafe_enqueue!(ptr, 2)
ForeignCallbacks.unsafe_push!(ptr, 2)
end
# TODO: Test no load from TLS in `unsafe_enqueue!`
@test ForeignCallbacks.dequeue!(lfq) === Some(2)
@test ForeignCallbacks.dequeue!(lfq) === nothing
# TODO: Test no load from TLS in `unsafe_push!`
@test ForeignCallbacks.popall!(lfq) == [2]
@test ForeignCallbacks.popall!(lfq) == []
end

@testset "callback" begin
Expand All @@ -52,12 +52,12 @@ end
end

@testset "IR" begin
let llvm = sprint(io->code_llvm(io, ForeignCallbacks.enqueue!, Tuple{ForeignCallbacks.LockfreeQueue{Int}, Int}))
let llvm = sprint(io->code_llvm(io, push!, Tuple{ForeignCallbacks.SingleConsumerStack{Int}, Int}))
@test !contains(llvm, "%thread_ptr")
@test !contains(llvm, "%pgcstack")
@test !contains(llvm, "%gcframe")
end
let llvm = sprint(io->code_llvm(io, ForeignCallbacks.unsafe_enqueue!, Tuple{Ptr{Cvoid}, Int}))
let llvm = sprint(io->code_llvm(io, ForeignCallbacks.unsafe_push!, Tuple{Ptr{Cvoid}, Int}))
@test !contains(llvm, "%thread_ptr")
@test !contains(llvm, "%pgcstack")
@test !contains(llvm, "%gcframe")
Expand All @@ -67,7 +67,7 @@ end
@test !contains(llvm, "%pgcstack")
@test !contains(llvm, "%gcframe")
end
let llvm = sprint(io->code_llvm(io, ForeignCallbacks.unsafe_enqueue!, Tuple{Ptr{Cvoid}, Message}))
let llvm = sprint(io->code_llvm(io, ForeignCallbacks.unsafe_push!, Tuple{Ptr{Cvoid}, Message}))
@test !contains(llvm, "%thread_ptr")
@test !contains(llvm, "%pgcstack")
@test !contains(llvm, "%gcframe")
Expand All @@ -89,7 +89,7 @@ end

function producer!(lfq)
for i in 1:100
ForeignCallbacks.enqueue!(lfq, i)
push!(lfq, i)
yield()
end
end
Expand All @@ -98,7 +98,7 @@ function unsafe_producer!(lfq)
for i in 1:100
GC.@preserve lfq begin
ptr = Base.pointer_from_objref(lfq)
ForeignCallbacks.unsafe_enqueue!(ptr, i)
ForeignCallbacks.unsafe_push!(ptr, i)
end
yield()
end
Expand All @@ -109,9 +109,8 @@ function consumer!(lfq)

done = false
while !done
data = ForeignCallbacks.dequeue!(lfq)
if data !== nothing
acc += something(data)
for x in ForeignCallbacks.popall!(lfq)
acc += x
end
done = acc == sum(1:100)*2*Threads.nthreads()
yield()
Expand All @@ -120,7 +119,7 @@ end

@testset "Queue threads" begin
@test Threads.nthreads() == Sys.CPU_THREADS
let lfq = ForeignCallbacks.LockfreeQueue{Int}()
let lfq = ForeignCallbacks.SingleConsumerStack{Int}()
@sync begin
for n in 1:2*Threads.nthreads()
Threads.@spawn producer!(lfq)
Expand All @@ -130,7 +129,7 @@ end
@test true
end

let lfq = ForeignCallbacks.LockfreeQueue{Int}()
let lfq = ForeignCallbacks.SingleConsumerStack{Int}()
@sync begin
for n in 1:2*Threads.nthreads()
Threads.@spawn unsafe_producer!(lfq)
Expand Down

2 comments on commit f4fde9d

@vchuravy
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/51925

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" f4fde9d2fcd78458063d78dca64c4234b0e8f9a5
git push origin v0.1.0

Please sign in to comment.