Skip to content

Commit

Permalink
Merge pull request #597 from willow-ahrens/kbd-sparselist-follow
Browse files Browse the repository at this point in the history
add follow protocol to sparse_list_level
  • Loading branch information
willow-ahrens authored Jun 13, 2024
2 parents 59866ce + ef2ae32 commit e62c15b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
27 changes: 27 additions & 0 deletions src/tensors/levels/sparse_list_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,33 @@ function instantiate(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, mode::Re
)
end


function instantiate(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, mode::Reader, subprotos, ::typeof(follow))
(lvl, pos) = (fbr.lvl, fbr.pos)
tag = lvl.ex
Tp = postype(lvl)
my_q = freshen(ctx, tag, :_q)
my_q_stop = freshen(ctx, tag, :_q_stop)
my_qos = freshen(ctx, tag, :_qos)

Furlable(
body = (ctx, ext) ->
Lookup(
body = (ctx, i) -> Thunk(
preamble = quote
$my_q = $(lvl.ptr)[$(ctx(pos))]
$my_q_stop = $(lvl.ptr)[$(ctx(pos)) + $(Tp(1))]
$my_qos = bin_scansearch($(lvl.idx), $(ctx(i)), $my_q, $my_q_stop-1)
end,
body = (ctx) -> Switch([
value(:($my_qos < $my_q_stop && $(lvl.idx)[$my_qos] == $(ctx(i)))) => instantiate(ctx, VirtualSubFiber(lvl.lvl, value(my_qos, Tp)), mode, subprotos),
literal(true) => FillLeaf(virtual_level_fill_value(lvl))
])
)
)
)
end

function instantiate(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, mode::Reader, subprotos, ::typeof(gallop))
(lvl, pos) = (fbr.lvl, fbr.pos)
tag = lvl.ex
Expand Down
4 changes: 0 additions & 4 deletions src/util/shims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ end

Base.@propagate_inbounds function bin_scansearch(v, x, lo::T1, hi::T2) where {T1<:Integer, T2<:Integer} # TODO types for `lo` and `hi` #406
u = T1(1)
stop = min(hi, lo + T1(32))
while lo + u < stop && v[lo] < x
lo += u
end
lo = lo - u
hi = hi + u
while lo < hi - u
Expand Down
25 changes: 25 additions & 0 deletions test/test_issues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -863,4 +863,29 @@ using SparseArrays
end
@test C1 == C2
end


# Basic Sparse Follow Test
let

n = 1000
A = Tensor(SparseList(Element(0.0)), fsprand(n, 100))
B = Tensor(SparseList(Element(0.0)), fsprand(n, 100))
C_follow = Tensor(SparseList(Element(0.0)))
@finch begin
C_follow .= 0
for i=_
C_follow[i] = A[follow(i)] * B[i]
end
end
C_walk = Tensor(SparseList(Element(0.0)))
@finch begin
C_walk .= 0
for i=_
C_walk[i] = A[i] * B[i]
end
end
@test C_follow == C_walk
end

end
1 change: 1 addition & 0 deletions test/test_merges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
fmts = [
(;fmt = (z) -> Tensor(Dense(SparseList(Element(z)))), proto = [literal_instance(walk), literal_instance(follow)]),
(;fmt = (z) -> Tensor(Dense(SparseList(Element(z)))), proto = [literal_instance(gallop), literal_instance(follow)]),
(;fmt = (z) -> Tensor(Dense(SparseList(Element(z)))), proto = [literal_instance(follow), literal_instance(follow)]),
(;fmt = (z) -> Tensor(Dense(SparseVBL(Element(z)))), proto = [literal_instance(walk), literal_instance(follow)]),
(;fmt = (z) -> Tensor(Dense(SparseVBL(Element(z)))), proto = [literal_instance(gallop), literal_instance(follow)]),
(;fmt = (z) -> Tensor(Dense(SparseByteMap(Element(z)))), proto = [literal_instance(walk), literal_instance(follow)]),
Expand Down

0 comments on commit e62c15b

Please sign in to comment.