Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add macro to create custom Ops also on aarch64 #871

Merged
merged 23 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ MPI.Types.duplicate

```@docs
MPI.Op
MPI.@RegisterOp
```

## Info objects
Expand Down
84 changes: 78 additions & 6 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ associative, and if `iscommutative` is true, assumed to be commutative as well.
- [`Allreduce!`](@ref)/[`Allreduce`](@ref)
- [`Scan!`](@ref)/[`Scan`](@ref)
- [`Exscan!`](@ref)/[`Exscan`](@ref)
- [`@RegisterOp`](@ref)
"""
mutable struct Op
val::MPI_Op
Expand Down Expand Up @@ -81,16 +82,21 @@ end

function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T}
len = unsafe_load(_len)
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
if !isconcretetype(T)
T = to_type(Datatype(unsafe_load(t))) # Ptr might actually point to a Julia object so we could unsafe_pointer_to_objref?
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
end
function copy(::Type{T}) where T
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
end
end
copy(T)
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
return nothing
end


function Op(f, T=Any; iscommutative=false)
@static if MPI_LIBRARY == "MicrosoftMPI" && Sys.WORD_SIZE == 32
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
Expand All @@ -107,3 +113,69 @@ function Op(f, T=Any; iscommutative=false)
finalizer(free, op)
return op
end

"""
@RegisterOp(f, T, internal=false)
vchuravy marked this conversation as resolved.
Show resolved Hide resolved

Register a custom operator [`Op`](@ref) using the function `f` statically.
On platfroms like AArch64, Julia does not support runtime closures,
being passed to C. The generic version of [`Op`](@ref) uses that
to support arbitrary function being passed as MPI reduction operators.
In contrast `@RegisterOp` can be used to statically declare a function to
be passed as an MPI operator.

```julia
function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Int)
# ...
MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root)
#...
```
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
!!! warning
Note that `@RegisterOp` works be introducing a new method to `Op`, potentially invalidating other users of `Op`.
vchuravy marked this conversation as resolved.
Show resolved Hide resolved

!!! note
`T` can be `Any`, but that will lead to a runtime dispatch.
"""
macro RegisterOp(f, T)
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
name_module = gensym(Symbol(f, :_, T, :_module))
# The gist is that we can use a method very similar to how we handle `min`/`max`
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
# but since this might be used from user code we can't use add_load_time_hook!
# this is why we introduce a new module that has a `__init__` function.
# If this module approach is too costly for loading MPI.jl for internal use we could use
# `add_load_time_hook`
expr = quote
module $(name_module)
# import ..$f, ..$T
$(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 strugles with import ..$f, ..$T
const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})))
function __init__()
$(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
end
import MPI: Op
# we can't create a const Op since MPI needs to be initialized?
function Op(::typeof($f), ::Type{<:$T}; iscommutative=false)
op = Op($OP_NULL.val, $(name_fptr)[])
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
$API.MPI_Op_create($(name_fptr)[], iscommutative, op)

finalizer($free, op)
end
end
end
expr.head = :toplevel
esc(expr)
end

@RegisterOp(min, Any)
@RegisterOp(max, Any)
@RegisterOp(+, Any)
@RegisterOp(*, Any)
@RegisterOp(&, Any)
@RegisterOp(|, Any)
@RegisterOp(⊻, Any)
10 changes: 6 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
CUDA = "3, 4, 5"
DoubleFloats = "1.4"
MPIPreferences = "0.1"
StaticArrays = "1"
TOML = "< 0.0.1, 1.0"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
9 changes: 7 additions & 2 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ if isroot
@test sum_mesg == sz .* mesg
end

function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Any)

if can_do_closures
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
else
operators = [MPI.SUM, +]
operators = [MPI.SUM, +, my_reduce]
end

for T = [Int]
Expand Down
Loading