Skip to content

Commit

Permalink
Merge 74edbd5 into aac9688
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored Sep 24, 2024
2 parents aac9688 + 74edbd5 commit 6cc31bf
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 26 deletions.
9 changes: 7 additions & 2 deletions docs/examples/03-reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,23 @@ function pool(S1::SummaryStat, S2::SummaryStat)
SummaryStat(m,v,n)
end

# Register the custom reduction operator. This is necessary only on platforms
# where Julia doesn't support closures as cfunctions (e.g. ARM), but can be used
# on all platforms for consistency.
MPI.@RegisterOp(pool, SummaryStat)

X = randn(10,3) .* [1,3,7]'

# Perform a scalar reduction
summ = MPI.Reduce(SummaryStat(X), pool, root, comm)
summ = MPI.Reduce(SummaryStat(X), pool, comm; root)

if MPI.Comm_rank(comm) == root
@show summ.var
end

# Perform a vector reduction:
# the reduction operator is applied elementwise
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, root, comm)
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, comm; root)

if MPI.Comm_rank(comm) == root
col_var = map(summ -> summ.var, col_summ)
Expand Down
1 change: 1 addition & 0 deletions docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ MPI.Types.duplicate

```@docs
MPI.Op
MPI.@RegisterOp
```

## Info objects
Expand Down
99 changes: 91 additions & 8 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ associative, and if `iscommutative` is true, assumed to be commutative as well.
- [`Allreduce!`](@ref)/[`Allreduce`](@ref)
- [`Scan!`](@ref)/[`Scan`](@ref)
- [`Exscan!`](@ref)/[`Exscan`](@ref)
- [`@RegisterOp`](@ref)
"""
mutable struct Op
val::MPI_Op
Expand Down Expand Up @@ -81,21 +82,36 @@ end

function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T}
len = unsafe_load(_len)
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
if !isconcretetype(T)
concrete_T = to_type(Datatype(unsafe_load(t))) # Ptr might actually point to a Julia object so we could unsafe_pointer_to_objref?
else
concrete_T = T
end
function copy(::Type{T}) where T
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
end
end
copy(concrete_T)
return nothing
end


function Op(f, T=Any; iscommutative=false)
@static if MPI_LIBRARY == "MicrosoftMPI" && Sys.WORD_SIZE == 32
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
error("""
User-defined reduction operators are not supported on 32-bit Windows.
See https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.
""")
elseif Sys.ARCH (:aarch64, :ppc64le, :powerpc64le) || startswith(lowercase(String(Sys.ARCH)), "arm")
error("User-defined reduction operators are currently not supported on non-Intel architectures.\nSee https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.")
error("""
User-defined reduction operators are currently not supported on non-Intel architectures.
See https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.
You may want to use `@RegisterOp` to statically register `f`.
""")
end
w = OpWrapper{typeof(f),T}(f)
fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
Expand All @@ -107,3 +123,70 @@ function Op(f, T=Any; iscommutative=false)
finalizer(free, op)
return op
end

"""
@RegisterOp(f, T)
Register a custom operator [`Op`](@ref) using the function `f` statically.
On platfroms like AArch64, Julia does not support runtime closures,
being passed to C. The generic version of [`Op`](@ref) uses runtime closures
to support arbitrary functions being passed as MPI reduction operators.
`@RegisterOp` statically adds a function to the set of functions allowed as
as an MPI operator.
```julia
function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Int)
# ...
MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root)
#...
```
!!! warning
Note that `@RegisterOp` works be introducing a new method of the generic function `Op`.
It can only be used as a top-level statement and may trigger method invalidations.
!!! note
`T` can be `Any`, but this will lead to a runtime dispatch.
"""
macro RegisterOp(f, T)
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
name_module = gensym(Symbol(f, :_, T, :_module))
# The gist is that we can use a method very similar to how we handle `min`/`max`
# but since this might be used from user code we can't use add_load_time_hook!
# this is why we introduce a new module that has a `__init__` function.
# If this module approach is too costly for loading MPI.jl for internal use we could use
# `add_load_time_hook`
expr = quote
module $(name_module)
# import ..$f, ..$T
$(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 strugles with import ..$f, ..$T
const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})))
function __init__()
$(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
end
import MPI: Op
# we can't create a const Op since MPI needs to be initialized?
function Op(::typeof($f), ::Type{<:$T}; iscommutative=false)
op = Op($OP_NULL.val, $(name_fptr)[])
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
$API.MPI_Op_create($(name_fptr)[], iscommutative, op)

finalizer($free, op)
end
end
end
expr.head = :toplevel
esc(expr)
end

@RegisterOp(min, Any)
@RegisterOp(max, Any)
@RegisterOp(+, Any)
@RegisterOp(*, Any)
@RegisterOp(&, Any)
@RegisterOp(|, Any)
@RegisterOp(, Any)
9 changes: 5 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
CUDA = "3, 4, 5"
DoubleFloats = "1.4"
MPIPreferences = "0.1"
StaticArrays = "1"
TOML = "< 0.0.1, 1.0"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
27 changes: 15 additions & 12 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ if isroot
@test sum_mesg == sz .* mesg
end

function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Any)

if can_do_closures
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
else
operators = [MPI.SUM, +]
operators = [MPI.SUM, +, my_reduce]
end

for T = [Int]
Expand Down Expand Up @@ -117,19 +122,17 @@ end

MPI.Barrier( MPI.COMM_WORLD )

if can_do_closures
send_arr = [Double64(i)/10 for i = 1:10]

result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
if rank == root
@test result [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
else
@test result === nothing
end
send_arr = [Double64(i)/10 for i = 1:10]

MPI.Barrier( MPI.COMM_WORLD )
result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
if rank == root
@test result [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
else
@test result === nothing
end

MPI.Barrier( MPI.COMM_WORLD )

GC.gc()
MPI.Finalize()
@test MPI.Finalized()

0 comments on commit 6cc31bf

Please sign in to comment.