Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
version = "0.4.7"
version = "0.4.8"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -19,7 +19,6 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down Expand Up @@ -79,7 +78,6 @@ MacroTools = "0.5"
MappedArrays = "0.4"
Metal = "1"
Octavian = "0.3"
PackageExtensionCompat = "1"
Random = "<0.0.1, 1.10"
SimpleTraits = "0.9.4"
SparseArrays = "<0.0.1, 1.10"
Expand Down
8 changes: 1 addition & 7 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ include("blocksparse/contract.jl")
include("blocksparse/contract_utilities.jl")
include("blocksparse/contract_generic.jl")
include("blocksparse/contract_sequential.jl")
include("blocksparse/contract_folds.jl")
include("blocksparse/contract_threads.jl")
include("blocksparse/contract_threaded.jl")
include("blocksparse/diagblocksparse.jl")
include("blocksparse/similar.jl")
include("blocksparse/combiner.jl")
Expand Down Expand Up @@ -221,9 +220,4 @@ end

function backend_octavian end

using PackageExtensionCompat
function __init__()
@require_extensions
end

end # module NDTensors
2 changes: 0 additions & 2 deletions NDTensors/src/blocksparse/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ function contract_blockoffsets(
alg = Algorithm"sequential"()
if using_threaded_blocksparse() && nthreads() > 1
alg = Algorithm"threaded_threads"()
# This code is a bit cleaner but slower:
# alg = Algorithm"threaded_folds"()
end
return contract_blockoffsets(
alg, boffs1, inds1, labels1, boffs2, inds2, labels2, indsR, labelsR
Expand Down
60 changes: 0 additions & 60 deletions NDTensors/src/blocksparse/contract_folds.jl

This file was deleted.

38 changes: 2 additions & 36 deletions NDTensors/src/blocksparse/contract_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,14 @@ function contract_blockoffsets(
indsR,
labelsR,
)
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
NR = length(labelsR)
ValNR = ValLength(labelsR)
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
labels1, labels2, labelsR
)

# Contraction plan element type
T = Tuple{Block{N1},Block{N2},Block{NR}}

# Thread-local collections of block contractions.
# Could use:
# ```julia
# FLoops.@reduce(contraction_plans = append!(T[], [(block1, block2, blockR)]))
# ```
# as a simpler alternative but it is slower.

contraction_plans = Vector{T}[T[] for _ in 1:nthreads()]

#
# Reserve some capacity
# In theory the maximum is length(boffs1) * length(boffs2)
# but in practice that is too much
#for contraction_plan in contraction_plans
# sizehint!(contraction_plan, max(length(boffs1), length(boffs2)))
#end
#

contract_blocks!(
alg,
contraction_plans,
boffs1,
boffs2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
contraction_plan = contract_blocks(
alg, boffs1, boffs2, labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR, ValNR
)

contraction_plan = reduce(vcat, contraction_plans)
blockoffsetsR = BlockOffsets{NR}()
nnzR = 0
for (_, _, blockR) in contraction_plan
Expand All @@ -60,7 +27,6 @@ function contract_blockoffsets(
nnzR += blockdim(indsR, blockR)
end
end

return blockoffsetsR, contraction_plan
end

Expand Down
91 changes: 91 additions & 0 deletions NDTensors/src/blocksparse/contract_threaded.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using .Expose: expose
function contract_blocks(
alg::Algorithm"threaded_threads",
boffs1,
boffs2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR::Val{NR},
) where {NR}
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
blocks1 = keys(boffs1)
blocks2 = keys(boffs2)
T = Tuple{Block{N1},Block{N2},Block{NR}}
return if length(blocks1) > length(blocks2)
tasks = map(
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
) do blocks1_partition
@spawn begin
block_contractions = T[]
for block1 in blocks1_partition
for block2 in blocks2
block_contraction = maybe_contract_blocks(
block1,
block2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
)
if !isnothing(block_contraction)
push!(block_contractions, block_contraction)
end
end
end
return block_contractions
end
end
all_block_contractions = T[]
for task in tasks
append!(all_block_contractions, fetch(task))
end
return all_block_contractions
else
tasks = map(
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
) do blocks2_partition
@spawn begin
block_contractions = T[]
for block2 in blocks2_partition
for block1 in blocks1
block_contraction = maybe_contract_blocks(
block1,
block2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
)
if !isnothing(block_contraction)
push!(block_contractions, block_contraction)
end
end
end
return block_contractions
end
end
all_block_contractions = T[]
for task in tasks
append!(all_block_contractions, fetch(task))
end
return all_block_contractions
end
end

function contract!(
::Algorithm"threaded_folds",
R::BlockSparseTensor,
labelsR,
tensor1::BlockSparseTensor,
labelstensor1,
tensor2::BlockSparseTensor,
labelstensor2,
contraction_plan,
)
executor = ThreadedEx()
return contract!(
R, labelsR, tensor1, labelstensor1, tensor2, labelstensor2, contraction_plan, executor
)
end
Loading
Loading