diff --git a/Project.toml b/Project.toml index f9fc2330..5a24122e 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -30,6 +31,7 @@ TenetAdaptExt = "Adapt" TenetChainRulesCoreExt = "ChainRulesCore" TenetChainRulesExt = "ChainRules" TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"] +TenetDaggerExt = "Dagger" TenetFiniteDifferencesExt = "FiniteDifferences" TenetGraphMakieExt = ["GraphMakie", "Makie"] @@ -39,6 +41,7 @@ Adapt = "4" ChainRules = "1.0" ChainRulesCore = "1.0" Combinatorics = "1.0" +Dagger = "0.18" DeltaArrays = "0.1.1" EinExprs = "0.5, 0.6" GraphMakie = "0.4,0.5" diff --git a/ext/TenetDaggerExt.jl b/ext/TenetDaggerExt.jl new file mode 100644 index 00000000..07d90b8f --- /dev/null +++ b/ext/TenetDaggerExt.jl @@ -0,0 +1,135 @@ +module TenetDaggerExt + +using Tenet +using Dagger: Dagger, ArrayOp, Context, ArrayDomain, EagerThunk, DArray + +struct Contract{T,N} <: ArrayOp{T,N} + ic::Vector{Symbol} + a::ArrayOp + ia::Vector{Symbol} + b::ArrayOp + ib::Vector{Symbol} + + function Contract(ic, a, ia, b, ib) + allunique(ia) || throw(ErrorException("ia must have unique indices")) + allunique(ib) || throw(ErrorException("ib must have unique indices")) + allunique(ic) || throw(ErrorException("ic must have unique indices")) + ic ⊆ ia ∪ ib || throw(ErrorException("ic must be a subset of ia ∪ ib")) + return new{promote_type(eltype(a), eltype(b)),length(ic)}(ic, a, ia, b, ib) + end +end + +function Base.size(x::Contract) + return Tuple( + Iterators.map(x.ic) do i + if i ∈ x.ia + size(x.a, findfirst(==(i), x.ia)) + elseif i ∈ x.ib + size(x.b, findfirst(==(i), x.ib)) + else + throw(ErrorException("index $i not found in a nor b")) + end + end, + ) +end + +function Dagger.Blocks(x::Contract) + return Dagger.Blocks(map(x.ic) do i + j = findfirst(==(i), x.ia) + isnothing(j) || return x.a.partitioning.blocksize[j] + + j = findfirst(==(i), x.ib) + isnothing(j) || return x.b.partitioning.blocksize[j] + + throw(ErrorException("index :$i not found in a nor b")) + end...) +end + +function selectdims(a, proj::Pair...) + return reduce(proj; init=a) do acc, (d, i) + selectdim(acc, d, i) + end +end + +contractfn(ic, chunk_a, ia, chunk_b, ib) = parent(contract(Tensor(chunk_a, ia), Tensor(chunk_b, ib); out=ic)) + +function Dagger.stage(ctx::Context, op::Contract{T,N}) where {T,N} + domain = Dagger.ArrayDomain([1:l for l in size(op)]) + partitioning = Dagger.Blocks(op) + + # NOTE careful with ÷ for dividing into partitions + subdomains = Array{ArrayDomain{N,NTuple{2,UnitRange{Int}}}}(undef, map(÷, size(op), partitioning.blocksize)) + for indices in eachindex(IndexCartesian(), subdomains) + subdomains[indices] = ArrayDomain( + map(Tuple(indices), partitioning.blocksize) do i, step + (i - 1) * step .+ (1:step) + end, + ) + end + + suminds = setdiff(op.ia ∪ op.ib, op.ic) + inner_perm_a = sortperm(map(i -> findfirst(==(i), op.ia), suminds)) + inner_perm_b = sortperm(map(i -> findfirst(==(i), op.ib), suminds)) + + mask_a = op.ic .∈ (op.ia,) + mask_b = op.ic .∈ (op.ib,) + outer_perm_a = map(i -> findfirst(==(i), op.ia), op.ic[mask_a]) + outer_perm_b = map(i -> findfirst(==(i), op.ib), op.ic[mask_b]) + + chunks = similar(subdomains, EagerThunk) + for indices in eachindex(IndexCartesian(), chunks) + outer_indices_a = Tuple(indices)[mask_a] + chunks_a = dropdims( + reduce(zip(outer_perm_a, outer_indices_a); init=Dagger.chunks(op.a)) do acc, (d, i) + selectdim(acc, d, i:i) + end; + dims=Tuple(outer_perm_a), + ) + chunks_a = permutedims(chunks_a, inner_perm_a) + + outer_indices_b = Tuple(indices)[mask_b] + chunks_b = dropdims( + reduce(zip(outer_perm_b, outer_indices_b); init=Dagger.chunks(op.b)) do acc, (d, i) + selectdim(acc, d, i:i) + end; + dims=Tuple(outer_perm_b), + ) + chunks_b = permutedims(chunks_b, inner_perm_b) + + chunks[indices] = Dagger.treereduce( + Dagger.AddComputeOp, + map(chunks_a, chunks_b) do chunk_a, chunk_b + # TODO add ThunkOptions: alloc_util, occupancy, ... + Dagger.@spawn contractfn(op.ic, chunk_a, op.ia, chunk_b, op.ib) + end, + ) + end + + return DArray(T, domain, subdomains, chunks, partitioning) +end + +function Tenet.contract( + a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing +) where {Ta,Tb,Na,Nb,Aa<:Dagger.DArray{Ta,Na},Ab<:Dagger.DArray{Tb,Nb}} + ia = collect(inds(a)) + ib = collect(inds(b)) + i = ∩(dims, ia, ib) + + ic::Vector{Symbol} = if isnothing(out) + setdiff(ia ∪ ib, i isa Base.AbstractVecOrTuple ? i : (i,))::Vector{Symbol} + else + out + end + + data = Dagger._to_darray(Contract(ic, parent(a), ia, parent(b), ib)) + return Tensor(data, ic) +end + +Tenet.contract(a::Tensor, b::Tensor{T,N,A}; kwargs...) where {T,N,A<:Dagger.DArray} = contract(b, a; kwargs...) +function Tenet.contract(a::Tensor{T,N,A}, b::Tensor; kwargs...) where {T,N,A<:Dagger.DArray} + throw( + ArgumentError("contract on a Dagger.DArray-backed Tensor with a non-DArray-backed Tensor is not yet supported") + ) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 9c2c5858..94c8cda0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,7 +4,9 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" diff --git a/test/integration/Dagger_test.jl b/test/integration/Dagger_test.jl new file mode 100644 index 00000000..70bf6a68 --- /dev/null +++ b/test/integration/Dagger_test.jl @@ -0,0 +1,45 @@ +using Tenet +using Dagger +using Distributed + +@testset "Dagger" begin + addprocs(1) + @everywhere using Dagger, Tenet + + try + @testset "Tensor" begin + data = rand(4, 4) + block_array = DArray(data, Dagger.Blocks(2, 2)) + indices = (:i, :j) + + tensor = Tensor(data, indices) + block_tensor = Tensor(block_array, indices) + + @test inds(block_tensor) == inds(tensor) + @test Array(parent(block_tensor)) ≈ parent(tensor) + end + + @testset "contract" begin + @testset "block-block" begin + data1, data2 = rand(4, 4), rand(4, 4) + block_array1 = distribute(data1, Dagger.Blocks(2, 2)) + block_array2 = distribute(data2, Dagger.Blocks(2, 2)) + + tensor1 = Tensor(data1, [:i, :j]) + tensor2 = Tensor(data2, [:j, :k]) + block_tensor1 = Tensor(block_array1, [:i, :j]) + block_tensor2 = Tensor(block_array2, [:j, :k]) + + contracted_tensor = contract(tensor1, tensor2) + contracted_block_tensor = contract(block_tensor1, block_tensor2) + + @test parent(contracted_block_tensor) isa DArray + @test inds(contracted_block_tensor) == [:i, :k] + @test all(==((2, 2)) ∘ size, Dagger.domainchunks(parent(contracted_block_tensor))) + @test collect(parent(contracted_block_tensor)) ≈ parent(contracted_tensor) + end + end + finally + rmprocs(workers()) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5fb4da2d..d16a31fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if VERSION >= v"1.10" @testset "Integration tests" verbose = true begin include("integration/ChainRules_test.jl") # include("integration/BlockArray_test.jl") + include("integration/Dagger_test.jl") include("integration/Makie_test.jl") end end