From 108e83a581ec7934cacad0697ec22c6e279e7b70 Mon Sep 17 00:00:00 2001
From: Julian P Samaroo <jpsamaroo@jpsamaroo.me>
Date: Thu, 20 Jul 2023 14:10:35 -0500
Subject: [PATCH 1/5] Add DaggerMPI subpackage for MPI integrations

Allows the `DArray` to participate in MPI operations by building on the
newly-added support for "MPI-style" partitioning (i.e. only the
rank-local partition is stored) via a new `MPIBlocks` partitioner. This
commit also implements MPI-powered `distribute` and `reduce` operations
for `MPIBlocks`-partitioned arrays, which support a variety of
distribution schemes and data transfer modes.
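
For example, a minimal usage sketch (assuming 4 MPI ranks, launched via
an MPI runner such as `mpiexec -n 4 julia script.jl`; keyword defaults
follow the definitions added below):

    using MPI, Dagger, DaggerMPI
    MPI.Init()

    # Each rank ends up holding only its 2x2 local partition of A.
    A = rand(4, 4)
    DA = distribute(A, MPIBlocks(2, 2))

    # With no `root` given, the reduction is Allreduce'd to every rank.
    total = reduce(+, DA)

    MPI.Finalize()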
---
 lib/DaggerMPI/Project.toml     |   8 +++
 lib/DaggerMPI/src/DaggerMPI.jl | 100 +++++++++++++++++++++++++++++++++
 2 files changed, 108 insertions(+)
 create mode 100644 lib/DaggerMPI/Project.toml
 create mode 100644 lib/DaggerMPI/src/DaggerMPI.jl

diff --git a/lib/DaggerMPI/Project.toml b/lib/DaggerMPI/Project.toml
new file mode 100644
index 000000000..730ad66de
--- /dev/null
+++ b/lib/DaggerMPI/Project.toml
@@ -0,0 +1,8 @@
+name = "DaggerMPI"
+uuid = "37bfb287-2338-4693-8557-581796463535"
+authors = ["Felipe de Alcântara Tomé <tomefelipe0@usp.br>", "Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
+version = "0.1.0"
+
+[deps]
+Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
+MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl
new file mode 100644
index 000000000..fc02d7759
--- /dev/null
+++ b/lib/DaggerMPI/src/DaggerMPI.jl
@@ -0,0 +1,100 @@
+module DaggerMPI
+using Dagger
+import Base: reduce, fetch, cat
+using MPI
+
+export MPIBlocks
+
+struct MPIBlocks{N} <: Dagger.AbstractSingleBlocks{N}
+    blocksize::NTuple{N, Int}
+end
+MPIBlocks(xs::Int...) = MPIBlocks(xs)
+
+function Dagger.distribute(::Type{A},
+                           x::Union{AbstractArray, Nothing},
+                           dist::MPIBlocks,
+                           comm::MPI.Comm=MPI.COMM_WORLD,
+                           root::Integer=0) where {A<:AbstractArray{T, N}} where {T, N}
+    isroot = MPI.Comm_rank(comm) == root
+
+    # TODO: Make better load balancing
+
+    data = Array{T, N}(undef, dist.blocksize)
+    if isroot
+        cs = Array{T, N}(undef, size(x))
+        parts = partition(dist, domain(x))
+        idx = 1
+        for part in parts
+            cs[idx:(idx - 1 + prod(dist.blocksize))] = x[part]
+            idx += prod(dist.blocksize)
+        end
+        MPI.Scatter!(MPI.UBuffer(cs, div(length(cs), MPI.Comm_size(comm))), data, comm, root=root)
+    else
+        MPI.Scatter!(nothing, data, comm, root=root)
+    end
+
+    data = Dagger.tochunk(data)
+
+    return Dagger.DArray(T, domain(data), domain(data), data, dist)
+end
+
+function Dagger.distribute(::Type{A},
+                           dist::MPIBlocks,
+                           comm::MPI.Comm=MPI.COMM_WORLD,
+                           root::Integer=0) where {A<:AbstractArray{T, N}} where {T, N}
+    return distribute(A, nothing, dist, comm, root)
+end
+
+function Dagger.distribute(x::AbstractArray,
+                           dist::MPIBlocks,
+                           comm::MPI.Comm=MPI.COMM_WORLD,
+                           root::Integer=0)
+    return distribute(typeof(x), x, dist, comm, root)
+end
+
+function Base.reduce(f::Function, x::Dagger.DArray{T,N,MPIBlocks{N}};
+                     dims=nothing,
+                     comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N}
+    if dims === nothing
+        if !acrossranks
+            return fetch(Dagger.reduce_async(f,x))
+        elseif root === nothing
+            return MPI.Allreduce(fetch(Dagger.reduce_async(f,x)), f, comm)
+        else
+            return MPI.Reduce(fetch(Dagger.reduce_async(f,x)), f, comm; root)
+        end
+    else
+        if dims isa Int
+            dims = (dims,)
+        end
+        d = reduce(x.domain, dims=dims)
+        ds = reduce(x.subdomains[1], dims=dims)
+        if !acrossranks
+            thunks = Dagger.spawn(b->reduce(f, b, dims=dims), x.chunks[1])
+            return Dagger.DArray(T, d, ds, thunks, x.partitioning; concat=x.concat)
+        else
+            tmp = collect(reduce(f, x, comm=comm, root=root, dims=dims, acrossranks=false))
+            if root === nothing
+                h = UInt(0)
+                for dim in 1:N
+                    if dim in dims
+                        continue
+                    end
+                    h = hash(x.subdomains[1].indexes[dim], h)
+                end
+                h = abs(Base.unsafe_trunc(Int32, h))
+                newc = MPI.Comm_split(comm, h, MPI.Comm_rank(comm))
+                chunks = Dagger.tochunk(reshape(MPI.Allreduce(tmp, f, newc), size(tmp)))
+            else
+                rcvbuf = MPI.Reduce(tmp, f, comm; root)
+                if root != MPI.Comm_rank(comm)
+                    return nothing
+                end
+                chunks = Dagger.tochunk(reshape(rcvbuf, size(tmp)))
+            end
+            return Dagger.DArray(T, d, ds, chunks, x.partitioning; concat=x.concat)
+        end
+    end
+end
+
+end # module

From 65fb8cfa78f113fce1b8704d03e43510071134c1 Mon Sep 17 00:00:00 2001
From: fda-tome <tomefelipe0@usp.br>
Date: Tue, 1 Aug 2023 15:10:56 -0300
Subject: [PATCH 2/5] Cross-rank collection and automatic block-size selection

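Adds a `Base.collect` method that gathers an `MPIBlocks`-partitioned
`DArray` across ranks, and lets partition dimensions be left as `nothing`
so that block sizes are chosen automatically from the communicator size.
A brief sketch of the intended usage (assuming the script is run under an
MPI launcher):

    using MPI, Dagger, DaggerMPI
    MPI.Init()

    # `nothing` dimensions are filled in automatically from the number
    # of ranks; each rank still stores only its local block.
    DA = distribute(rand(100, 100), MPIBlocks(nothing, nothing))

    full_array = collect(DA)                    # Allgather on every rank
    local_part = collect(DA, acrossranks=false) # rank-local data only

    MPI.Finalize()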
---
 lib/DaggerMPI/src/DaggerMPI.jl | 63 +++++++++++++++++++++++++++++-----
 1 file changed, 55 insertions(+), 8 deletions(-)

diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl
index fc02d7759..1429d6740 100644
--- a/lib/DaggerMPI/src/DaggerMPI.jl
+++ b/lib/DaggerMPI/src/DaggerMPI.jl
@@ -6,9 +6,9 @@ using MPI
 export MPIBlocks
 
 struct MPIBlocks{N} <: Dagger.AbstractSingleBlocks{N}
-    blocksize::NTuple{N, Int}
+    blocksize::NTuple{N, Union{Int, Nothing}}
 end
-MPIBlocks(xs::Int...) = MPIBlocks(xs)
+MPIBlocks(xs::Union{Int,Nothing}...) = MPIBlocks(xs)
 
 function Dagger.distribute(::Type{A},
                            x::Union{AbstractArray, Nothing},
@@ -16,17 +16,38 @@ function Dagger.distribute(::Type{A},
                            comm::MPI.Comm=MPI.COMM_WORLD,
                            root::Integer=0) where {A<:AbstractArray{T, N}} where {T, N}
     isroot = MPI.Comm_rank(comm) == root
-
+    csz = MPI.Comm_size(comm)
+    if any(isnothing, dist.blocksize)
+        newdims = map(collect(dist.blocksize)) do d
+            something(d, 1)
+        end
+        if isroot
+            for i in 1:N
+                if dist.blocksize[i] !== nothing
+                    continue
+                end
+                if csz * prod(newdims) >= length(x)
+                    break
+                end
+                newdims[i] = min(size(x, i),  cld(length(x), csz * prod(newdims)))
+             end
+        end
+        newdims = MPI.bcast(newdims, comm, root=root)
+        dist = MPIBlocks(newdims...)
+    end
+    d = MPI.bcast(domain(x), comm, root=root)
     # TODO: Make better load balancing
-
-    data = Array{T, N}(undef, dist.blocksize)
+    data = Array{T,N}(undef, dist.blocksize)
     if isroot
         cs = Array{T, N}(undef, size(x))
+        # TODO: deal with uneven partitions (Scatterv, possibly)
+        @assert prod(dist.blocksize) * csz == length(x) "Cannot match length of array and number of ranks"
         parts = partition(dist, domain(x))
         idx = 1
         for part in parts
-            cs[idx:(idx - 1 + prod(dist.blocksize))] = x[part]
-            idx += prod(dist.blocksize)
+            step = prod(map(length, part.indexes))
+            cs[idx:(idx - 1 + step)] = x[part]
+            idx += step
         end
         MPI.Scatter!(MPI.UBuffer(cs, div(length(cs), MPI.Comm_size(comm))), data, comm, root=root)
     else
@@ -35,7 +56,7 @@ function Dagger.distribute(::Type{A},
 
     data = Dagger.tochunk(data)
 
-    return Dagger.DArray(T, domain(data), domain(data), data, dist)
+    return Dagger.DArray(T, d, domain(data), data, dist)
 end
 
 function Dagger.distribute(::Type{A},
@@ -97,4 +118,30 @@ function Base.reduce(f::Function, x::Dagger.DArray{T,N,MPIBlocks{N}};
     end
 end
 
+function Base.collect(x::Dagger.DArray{T,N,MPIBlocks{N}};
+                     comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N}
+    if !acrossranks
+        a = fetch(x)
+        if isempty(x.chunks)
+            return Array{eltype(x)}(undef, size(x)...)
+        end
+
+        dimcatfuncs = [(d...) -> d.concat(d..., dims=i) for i in 1:ndims(x)]
+        Dagger.treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))
+    else
+        datasnd = collect(x, acrossranks=false)
+        if root === nothing
+            tmp = MPI.Allgather(datasnd, comm)
+        else
+            tmp = MPI.Gather(datasnd, comm, root=root)
+            if tmp === nothing
+                return
+            end
+        end
+        return(reshape(tmp, size(x.domain)))
+    end
+end
+
+
+
 end # module

From 117e8ca4e1538edc0c898d6416e4f42cf7b097eb Mon Sep 17 00:00:00 2001
From: fda-tome <tomefelipe0@usp.br>
Date: Mon, 14 Aug 2023 13:04:50 -0300
Subject: [PATCH 3/5] Fix ind2sub deprecation

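The deprecated `ind2sub` is replaced with `CartesianIndices`, which
provides the equivalent linear-to-Cartesian index conversion. For
reference (a standalone illustration, not part of the diff):

    julia> idces = CartesianIndices((3, 4));

    julia> Tuple(idces[5])   # formerly ind2sub((3, 4), 5)
    (2, 2)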
---
 src/lib/domain-blocks.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/lib/domain-blocks.jl b/src/lib/domain-blocks.jl
index 30dc4b085..f095c0c04 100644
--- a/src/lib/domain-blocks.jl
+++ b/src/lib/domain-blocks.jl
@@ -15,10 +15,11 @@ function _getindex(x::DomainBlocks{N}, idx::Tuple) where N
 end
 
 function getindex(x::DomainBlocks{N}, idx::Int) where N
     if N == 1
         _getindex(x, (idx,))
     else
-        _getindex(x, ind2sub(x, idx))
+        idces = CartesianIndices(x)
+        _getindex(x, Tuple(idces[idx]))
     end
 end
 

From c133eaecdf5cf0f3d0e5fba6e3fedbf7bb5f9d13 Mon Sep 17 00:00:00 2001
From: fda-tome <tomefelipe0@usp.br>
Date: Mon, 14 Aug 2023 13:11:21 -0300
Subject: [PATCH 4/5] [DaggerMPI] Tests and minor corrections for DArray

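The new tests cover the MPI-aware constructors (`rand`, `ones`, `zeros`),
`distribute`, `collect`, and rooted vs. non-rooted `sum`, `prod`, and
dimensional `reduce`. A sketch of the rooted-reduction pattern the tests
exercise (to be launched with more than one rank under an MPI runner such
as `mpiexec`; the exact invocation may vary):

    using MPI, Dagger, DaggerMPI, Test
    MPI.Init()

    X = ones(MPIBlocks(nothing, nothing), 100, 100)
    if MPI.Comm_rank(MPI.COMM_WORLD) == 0
        @test sum(X, root=0) == 10000       # only the root receives the result
    else
        @test sum(X, root=0) === nothing    # non-root ranks get `nothing`
    end

    MPI.Finalize()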
---
 Project.toml                   |   5 +-
 lib/DaggerMPI/Project.toml     |   5 +
 lib/DaggerMPI/src/DaggerMPI.jl | 186 +++++++++++++++++++++++++--------
 lib/DaggerMPI/test/runtests.jl | 154 +++++++++++++++++++++++++++
 4 files changed, 307 insertions(+), 43 deletions(-)
 create mode 100644 lib/DaggerMPI/test/runtests.jl

diff --git a/Project.toml b/Project.toml
index 93a570f95..12ef13e87 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,11 +8,14 @@ ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -33,8 +36,8 @@ TimespanLogging = "0.1"
 julia = "1.7"
 
 [extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
 test = ["Test", "Pkg"]
diff --git a/lib/DaggerMPI/Project.toml b/lib/DaggerMPI/Project.toml
index 730ad66de..8d00014d4 100644
--- a/lib/DaggerMPI/Project.toml
+++ b/lib/DaggerMPI/Project.toml
@@ -6,3 +6,8 @@ version = "0.1.0"
 [deps]
 Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
+MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl
index 1429d6740..feaab8379 100644
--- a/lib/DaggerMPI/src/DaggerMPI.jl
+++ b/lib/DaggerMPI/src/DaggerMPI.jl
@@ -1,6 +1,7 @@
 module DaggerMPI
-using Dagger
-import Base: reduce, fetch, cat
+using Dagger, SparseArrays, Random
+import Statistics: sum, prod, mean
+import Base: reduce, fetch, cat, prod, sum
 using MPI
 
 export MPIBlocks
@@ -10,30 +11,68 @@ struct MPIBlocks{N} <: Dagger.AbstractSingleBlocks{N}
 end
 MPIBlocks(xs::Union{Int,Nothing}...) = MPIBlocks(xs)
 
-function Dagger.distribute(::Type{A},
-                           x::Union{AbstractArray, Nothing},
-                           dist::MPIBlocks,
-                           comm::MPI.Comm=MPI.COMM_WORLD,
-                           root::Integer=0) where {A<:AbstractArray{T, N}} where {T, N}
+function _defineBlocks(dist::MPIBlocks, comm::MPI.Comm, root::Integer, dims::Tuple)
     isroot = MPI.Comm_rank(comm) == root
+    newdims = map(collect(dist.blocksize)) do d
+        something(d, 1)
+    end
     csz = MPI.Comm_size(comm)
-    if any(isnothing, dist.blocksize)
-        newdims = map(collect(dist.blocksize)) do d
-            something(d, 1)
+    if isroot
+        homogeneous = map(dist.blocksize) do i
+            if i !== nothing 
+                if i < (prod(dims) / csz) ^ (1/length(dims))
+                    return true
+                else
+                    return false
+                end
+            else
+                return false
+            end
         end
-        if isroot
-            for i in 1:N
+        if isinteger((prod(dims) / csz) ^ (1 / length(dims))) && !any(homogeneous)
+            for i in 1:length(dims)
+                if dist.blocksize[i] === nothing
+                    newdims[i] = (prod(dims) / csz) ^ (1 / length(dims))
+                end
+            end
+        else
+            for i in 1:length(dims)
                 if dist.blocksize[i] !== nothing
                     continue
                 end
-                if csz * prod(newdims) >= length(x)
+                if csz * prod(newdims) >= prod(dims)
                     break
                 end
-                newdims[i] = min(size(x, i),  cld(length(x), csz * prod(newdims)))
+                newdims[i] = min(dims[i],  cld(prod(dims), csz * prod(newdims)))
              end
-        end
-        newdims = MPI.bcast(newdims, comm, root=root)
-        dist = MPIBlocks(newdims...)
+         end
+    end
+    newdims = MPI.bcast(newdims, comm, root=root)
+    MPIBlocks(newdims...)
+end
+
+function _finish_allocation(f::Function, dist::MPIBlocks, dims, comm::MPI.Comm, root::Integer)
+    if any(isnothing, dist.blocksize)
+       dist =  _defineBlocks(dist, comm, root, dims)
+    end
+    rnk = MPI.Comm_rank(comm)
+    d = ArrayDomain(map(x->1:x, dims))
+    data = f(dist.blocksize)
+    T = eltype(data)
+    ds = partition(dist, d)[rnk + 1]
+    return Dagger.DArray(T, d, ds, Dagger.tochunk(data), dist)
+end
+
+function Dagger.distribute(::Type{A},
+                           x::Union{AbstractArray, Nothing},
+                           dist::MPIBlocks,
+                           comm::MPI.Comm=MPI.COMM_WORLD,
+                           root::Integer=0) where {A<:AbstractArray{T, N}} where {T, N}
+    rnk = MPI.Comm_rank(comm)
+    isroot = rnk == root
+    csz = MPI.Comm_size(comm)
+    if any(isnothing, dist.blocksize)
+        dist = _defineBlocks(dist, comm, root, size(x))
     end
     d = MPI.bcast(domain(x), comm, root=root)
     # TODO: Make better load balancing
@@ -49,14 +88,14 @@ function Dagger.distribute(::Type{A},
             cs[idx:(idx - 1 + step)] = x[part]
             idx += step
         end
-        MPI.Scatter!(MPI.UBuffer(cs, div(length(cs), MPI.Comm_size(comm))), data, comm, root=root)
+        MPI.Scatter!(MPI.UBuffer(cs, div(length(cs), csz)), data, comm, root=root)
     else
         MPI.Scatter!(nothing, data, comm, root=root)
     end
 
     data = Dagger.tochunk(data)
-
-    return Dagger.DArray(T, d, domain(data), data, dist)
+    ds = partition(dist, d)[rnk + 1] 
+    return Dagger.DArray(T, d, ds, data, dist)
 end
 
 function Dagger.distribute(::Type{A},
@@ -92,56 +131,119 @@ function Base.reduce(f::Function, x::Dagger.DArray{T,N,MPIBlocks{N}};
         ds = reduce(x.subdomains[1], dims=dims)
         if !acrossranks
             thunks = Dagger.spawn(b->reduce(f, b, dims=dims), x.chunks[1])
-            return Dagger.DArray(T, d, ds, thunks, x.partitioning; concat=x.concat)
+            return Dagger.DArray(T, d, ds, thunks, x.partitioning, x.concat)
         else
-            tmp = collect(reduce(f, x, comm=comm, root=root, dims=dims, acrossranks=false))
-            if root === nothing
-                h = UInt(0)
-                for dim in 1:N
-                    if dim in dims
-                        continue
-                    end
-                    h = hash(x.subdomains[1].indexes[dim], h)
+            tmp = collect(reduce(f, x, comm=comm, root=root, dims=dims, acrossranks=false), acrossranks = false)
+            h = UInt(0)
+            for dim in 1:N
+                if dim in dims
+                    continue
                 end
-                h = abs(Base.unsafe_trunc(Int32, h))
-                newc = MPI.Comm_split(comm, h, MPI.Comm_rank(comm))
+                h = hash(x.subdomains[1].indexes[dim], h)
+            end
+            h = abs(Base.unsafe_trunc(Int32, h))
+            newc = MPI.Comm_split(comm, h, MPI.Comm_rank(comm))
+            if root === nothing 
                 chunks = Dagger.tochunk(reshape(MPI.Allreduce(tmp, f, newc), size(tmp)))
             else
-                rcvbuf = MPI.Reduce(tmp, f, comm; root)
+                rcvbuf = MPI.Reduce(tmp, f, newc, root=root)
                 if root != MPI.Comm_rank(comm)
                     return nothing
                 end
-                chunks = Dagger.tochunk(reshape(rcvbuf, size(tmp)))
+                chunks = reshape(rcvbuf, size(tmp))
+                chunks = Dagger.tochunk(chunks)
             end
-            return Dagger.DArray(T, d, ds, chunks, x.partitioning; concat=x.concat)
+            return Dagger.DArray(T, d, ds, chunks, x.partitioning, x.concat)
         end
     end
 end
 
+function Base.rand(dist::MPIBlocks, eltype::Type, dims, comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0)  
+    rank = MPI.Comm_rank(comm) 
+    s = rand(UInt)
+    f(block) = rand(MersenneTwister(s+rank), eltype, block)
+    _finish_allocation(f, dist, dims, comm, root)
+end
+Base.rand(dist::MPIBlocks, t::Type, dims::Integer...; comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0) = rand(dist, t, dims, comm, root)
+Base.rand(dist::MPIBlocks, dims::Integer...; comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0) = rand(dist, Float64, dims, comm, root)
+Base.rand(dist::MPIBlocks, dims::Tuple, comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0) = rand(dist, Float64, dims, comm, root)
+
+function Base.randn(dist::MPIBlocks, dims, comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0)
+    rank = MPI.Comm_rank(comm) 
+    s = rand(UInt)
+    f(block) = randn(MersenneTwister(s+rank), Float64, block)
+    _finish_allocation(f, dist, dims, comm, root)
+end
+Base.randn(p::MPIBlocks, dims::Integer...) = randn(p, dims)
+
+function Base.ones(dist::MPIBlocks, eltype::Type, dims, comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0)
+    f(blocks) = ones(eltype, blocks)
+    _finish_allocation(f, dist, dims, comm, root)
+end
+Base.ones(dist::MPIBlocks, t::Type, dims::Integer...) = ones(dist, t, dims)
+Base.ones(dist::MPIBlocks, dims::Integer...) = ones(dist, Float64, dims)
+Base.ones(dist::MPIBlocks, dims::Tuple) = ones(dist, Float64, dims)
+
+function Base.zeros(dist::MPIBlocks, eltype::Type, dims, comm::MPI.Comm=MPI.COMM_WORLD, root::Integer=0)
+    f(blocks) = zeros(eltype, blocks) 
+    _finish_allocation(f, dist, dims, comm, root)
+end
+Base.zeros(dist::MPIBlocks, t::Type, dims::Integer...) = zeros(dist, t, dims)
+Base.zeros(dist::MPIBlocks, dims::Integer...) = zeros(dist, Float64, dims)
+Base.zeros(dist::MPIBlocks, dims::Tuple) = zeros(dist, Float64, dims)
+
+sum(x::Dagger.DArray{T,N,MPIBlocks{N}}; dims::Union{Int,Nothing}=nothing, comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N} = reduce(+, x, dims=dims, comm=comm, root=root, acrossranks=acrossranks) 
+
+prod(x::Dagger.DArray{T,N,MPIBlocks{N}}; dims::Union{Int,Nothing}=nothing, comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N} = reduce(*, x, dims=dims, comm=comm, root=root, acrossranks=acrossranks) 
+
+#mean(x::Dagger.DArray{T,N,MPIBlocks{N}}; dims::Union{Int,Nothing}=nothing, comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N} = reduce(mean, x, dims=dims, comm=comm, root=root, acrossranks=acrossranks)
+
 function Base.collect(x::Dagger.DArray{T,N,MPIBlocks{N}};
                      comm=MPI.COMM_WORLD, root=nothing, acrossranks::Bool=true) where {T,N}
     if !acrossranks
-        a = fetch(x)
         if isempty(x.chunks)
             return Array{eltype(x)}(undef, size(x)...)
         end
-
-        dimcatfuncs = [(d...) -> d.concat(d..., dims=i) for i in 1:ndims(x)]
-        Dagger.treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))
+        dimcatfuncs = [(d...) -> x.concat(d..., dims=i) for i in 1:ndims(x)]
+        Dagger.treereduce_nd(dimcatfuncs, asyncmap(fetch, x.chunks))
     else
+        csz = MPI.Comm_size(comm)
         datasnd = collect(x, acrossranks=false)
+        domsnd = collect(x.subdomains[1].indexes)
         if root === nothing
-            tmp = MPI.Allgather(datasnd, comm)
+            datatmp = MPI.Allgather(datasnd, comm)
+            domtmp = MPI.Allgather(domsnd, comm)
         else
-            tmp = MPI.Gather(datasnd, comm, root=root)
-            if tmp === nothing
+            datatmp = MPI.Gather(datasnd, comm, root=root)
+            if datatmp === nothing
                 return
             end
         end
-        return(reshape(tmp, size(x.domain)))
+        idx = 1
+        domtmp = reshape(domtmp, (N, csz))
+        doms = Array{Vector{UnitRange{Int64}}}(undef, csz)
+        domused = Array{Vector{UnitRange{Int64}}}(undef, csz)
+        for i in 1:csz
+            doms[i] = domtmp[:, i]
+            domused[i] = [0:0 for j in 1:N]
+        end
+        step = Int(length(datatmp) / csz)
+        data = Array{eltype(datatmp), N}(undef, size(x.domain))
+        for i in 1:csz
+            if doms[i] in domused
+                idx += step
+                continue
+            end
+            data[doms[i]...] = datatmp[idx:(idx - 1 + step)]
+            domused[i] = doms[i]
+            idx += step
+        end
+        return(reshape(data, size(x.domain)))
     end
 end
 
 
 
 end # module
+
+#TODO: sort, broadcast ops, mat-mul, setindex, getindex, mean
diff --git a/lib/DaggerMPI/test/runtests.jl b/lib/DaggerMPI/test/runtests.jl
new file mode 100644
index 000000000..aa617ed31
--- /dev/null
+++ b/lib/DaggerMPI/test/runtests.jl
@@ -0,0 +1,154 @@
+using LinearAlgebra, SparseArrays, Random, SharedArrays, Test, MPI, DaggerMPI, Dagger
+import Dagger: chunks, DArray, domainchunks, treereduce_nd
+import Distributed: myid, procs
+
+MPI.Init()
+
+@testset "DArray MPI constructor" begin
+    x = rand(MPIBlocks(nothing,nothing), 10, 10)
+    @test x isa Dagger.DArray{eltype(x), 2, MPIBlocks{2}, typeof(cat)}
+    @test collect(x) == DArray(eltype(x), x.domain, x.subdomains[1], x.chunks[1], x.partitioning, x.concat) |> collect
+end
+
+@testset "rand" begin
+    function test_rand(X1)
+        X2 = collect(X1)
+        @test isa(X1, Dagger.DArray)
+        @test X2 |> size == (100, 100)
+        @test all(X2 .>= 0.0)
+        @test size(chunks(X1)) == (1, 1)
+        @test domain(X1) == ArrayDomain(1:100, 1:100)
+        @test domainchunks(X1) |> size == (1, 1)
+        @test collect(X1) == collect(X1)
+    end
+    X = rand(MPIBlocks(nothing, nothing), 100, 100)
+    test_rand(X)
+    R = rand(MPIBlocks(nothing), 20)
+    r = collect(R)
+    @test r[1:10] != r[11:20]
+end
+
+@testset "local sum" begin
+    X = ones(MPIBlocks(nothing, nothing), 100, 100)
+    @test sum(X, acrossranks=false) == 10000 / MPI.Comm_size(MPI.COMM_WORLD)
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    @test sum(Y, acrossranks=false) == 0
+end
+
+@testset "sum across ranks" begin
+    X = ones(MPIBlocks(nothing, nothing), 100, 100)
+    @test sum(X) == 10000
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    @test sum(Y) == 0
+end
+
+@testset "rooted sum across ranks" begin
+    X = ones(MPIBlocks(nothing, nothing), 100, 100)
+    isroot = MPI.Comm_rank(MPI.COMM_WORLD) == 0
+    if isroot
+        @test sum(X, root=0) == 10000
+    else
+        @test sum(X, root=0) == nothing
+    end
+
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    if isroot
+        @test sum(Y, root=0) == 0
+    else
+        @test sum(Y, root=0) == nothing
+    end
+end
+
+@testset "local prod" begin
+    x = ones(100, 100)
+    x = 2 .* x
+    x = x ./ 10
+    X = distribute(x, MPIBlocks(nothing, nothing))
+    @test prod(X, acrossranks=false) == 0.2 ^ (10000 / MPI.Comm_size(MPI.COMM_WORLD))
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    @test prod(Y, acrossranks=false) == 0
+end
+
+@testset "prod across ranks" begin
+    x = randn(100, 100)
+    X = distribute(x, MPIBlocks(nothing, nothing))
+    @test prod(X) == prod(x)
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    @test prod(Y) == 0
+end
+
+@testset "rooted prod across ranks" begin
+    x = randn(100, 100)
+    X = distribute(x, MPIBlocks(nothing, nothing))
+    isroot = MPI.Comm_rank(MPI.COMM_WORLD) == 0
+
+    if isroot
+        @test prod(X, root=0) == prod(x)
+    else
+        @test prod(X, root=0) == nothing
+    end
+
+    Y = zeros(MPIBlocks(nothing, nothing), 100, 100)
+    if isroot
+        @test prod(Y, root=0) == 0
+    else
+        @test prod(Y, root=0) == nothing
+    end
+end
+
+@testset "distributing an array" begin
+    function test_dist(X)
+        X1 = distribute(X, MPIBlocks(nothing, nothing))
+        Xc = fetch(X1)
+        @test Xc isa DArray{eltype(X), ndims(X), MPIBlocks{2}, typeof(cat)}
+        @test chunks(Xc) |> size == (1, 1)
+        @test domainchunks(Xc) |> size == (1, 1)
+    end
+    x = [1 2; 3 4]
+    test_dist(rand(100, 100))
+end
+
+@testset "reducedim across ranks" begin
+    X = randn(MPIBlocks(nothing, nothing), 100, 100)
+    @test reduce(+, collect(X), dims= 1) ≈ collect(reduce(+, X, dims=1))
+    @test reduce(+, collect(X), dims= 2) ≈ collect(reduce(+, X, dims=2))
+    
+    @test sum(collect(X), dims= 1) ≈ collect(sum(X, dims=1))
+    @test sum(collect(X), dims= 2) ≈ collect(sum(X, dims=2))
+end
+
+@testset "rooted reducedim across ranks" begin
+    X = randn(MPIBlocks(nothing, nothing), 100, 100)
+    isroot = MPI.Comm_rank(MPI.COMM_WORLD) == 0
+    if isroot
+        @test collect(reduce(+, X, dims=1, root=0), acrossranks=false) ≈ collect(reduce(+, X, dims=1), acrossranks=false)
+        @test collect(reduce(+, X, dims=2, root=0), acrossranks=false) ≈ collect(reduce(+, X, dims=2), acrossranks=false)
+    else
+        @test reduce(+, X, dims=1, root=0) == nothing
+        reduce(+, X, dims=1)
+        @test reduce(+, X, dims= 2, root=0) == nothing
+        reduce(+, X, dims=2)
+    end
+    
+    if isroot
+        @test collect(sum(X, dims=1, root=0), acrossranks=false) ≈ collect(sum(X, dims=1), acrossranks=false)
+        @test collect(sum(X, dims=2, root=0), acrossranks=false) ≈ collect(sum(X, dims=2), acrossranks=false)
+    else
+        @test sum(X, dims=1, root=0) == nothing
+        sum(X, dims=1)
+        @test sum(X, dims=2, root=0) == nothing
+        sum(X, dims=2)
+    end
+end
+
+@testset "local reducedim" begin
+    X = rand(MPIBlocks(nothing, nothing), 100, 100)
+    @test reduce(+, collect(X, acrossranks=false), dims=1) == collect(reduce(+, X, dims=1, acrossranks=false), acrossranks=false)
+    @test reduce(+, collect(X, acrossranks=false), dims=2) == collect(reduce(+, X, dims=2, acrossranks=false), acrossranks=false)
+
+    X = rand(MPIBlocks(nothing, nothing), 100, 100)
+    @test sum(collect(X, acrossranks=false), dims=1) == collect(sum(X, dims=1, acrossranks=false), acrossranks=false)
+    @test sum(collect(X, acrossranks=false), dims=2) == collect(sum(X, dims=2, acrossranks=false), acrossranks=false)
+end
+
+MPI.Finalize()

From af51f416ffda33051d7c79219397168f9f561741 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felipe=20de=20Alc=C3=A2ntara=20Tom=C3=A9?=
 <62853093+fda-tome@users.noreply.github.com>
Date: Wed, 16 Aug 2023 08:44:51 -0300
Subject: [PATCH 5/5] Minor correction to the rooted collect Gather operation

---
 lib/DaggerMPI/src/DaggerMPI.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lib/DaggerMPI/src/DaggerMPI.jl b/lib/DaggerMPI/src/DaggerMPI.jl
index feaab8379..cb6f4c297 100644
--- a/lib/DaggerMPI/src/DaggerMPI.jl
+++ b/lib/DaggerMPI/src/DaggerMPI.jl
@@ -215,6 +215,7 @@ function Base.collect(x::Dagger.DArray{T,N,MPIBlocks{N}};
             domtmp = MPI.Allgather(domsnd, comm)
         else
             datatmp = MPI.Gather(datasnd, comm, root=root)
+            domtmp = MPI.Gather(domsnd, comm, root=root)
             if datatmp === nothing
                 return
             end