Skip to content

Commit

Permalink
Merge pull request #139 from omlins/EnzymeExt
Browse files Browse the repository at this point in the history
Use extension for Enzyme dependency
  • Loading branch information
omlins authored Jan 16, 2024
2 parents 999f0e5 + 5413851 commit d49dd44
Show file tree
Hide file tree
Showing 15 changed files with 96 additions and 67 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ version = "0.11.0"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CellArrays = "d35fcfd7-7af4-4c67-b1aa-d78070614af4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
ParallelStencil_EnzymeExt = "Enzyme"

[compat]
AMDGPU = "0.6, 0.7, 0.8"
CUDA = "3.12, 4, 5"
Expand All @@ -26,4 +31,4 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "TOML"]
test = ["Test", "TOML", "Enzyme"]
3 changes: 3 additions & 0 deletions ext/ParallelStencil_EnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Extension module registered under `[extensions]` in Project.toml with Enzyme as its trigger.
module ParallelStencil_EnzymeExt
    # The actual method definitions are kept next to the ParallelKernel sources;
    # this module only pulls that file in.
    include(joinpath(
        @__DIR__, "..", "src", "ParallelKernel", "EnzymeExt", "autodiff_gpu.jl",
    ))
end
3 changes: 1 addition & 2 deletions src/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
To see a description of a function type `?<functionname>`.
"""
module AD
import ..ParallelKernel.AD: init_AD, autodiff_deferred!, autodiff_deferred_thunk!
export autodiff_deferred!, autodiff_deferred_thunk!
import ..ParallelKernel.AD: autodiff_deferred!, autodiff_deferred_thunk!

end # Module AD
32 changes: 0 additions & 32 deletions src/ParallelKernel/AD.jl

This file was deleted.

29 changes: 29 additions & 0 deletions src/ParallelKernel/EnzymeExt/AD.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Module AD
Provides GPU-compatible wrappers for automatic differentiation functions of the Enzyme.jl package. Enzyme needs to be imported before ParallelStencil in order to have it load the corresponding extension. Consult the Enzyme documentation to learn how to use the wrapped functions.
# Usage
import ParallelKernel.AD
# Functions
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
!!! note "Enzyme runtime activity default"
If ParallelKernel is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
To see a description of a function type `?<functionname>`.
"""
module AD
using ..Exceptions

export autodiff_deferred!, autodiff_deferred_thunk!

# Message carried by the `ExtensionLoadError` thrown from the fallback methods below.
const ERRMSG_EXTENSION_LOAD_ERROR = "AD: the Enzyme extension was not loaded. Make sure to import Enzyme before ParallelStencil."

# Generic fallbacks. More specific methods are to be defined in the AD extension modules.

# NOTE: a call will be triggered from @init_parallel_kernel; without the extension it is a no-op.
function init_AD(args...)
    return nothing
end

# Without the extension there is nothing to wrap: report the missing extension.
function autodiff_deferred!(args...)
    @ExtensionLoadError(ERRMSG_EXTENSION_LOAD_ERROR)
end

# Same fallback behaviour as `autodiff_deferred!`.
function autodiff_deferred_thunk!(args...)
    @ExtensionLoadError(ERRMSG_EXTENSION_LOAD_ERROR)
end

end # Module AD
19 changes: 19 additions & 0 deletions src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import ParallelStencil
import ParallelStencil: PKG_THREADS
import Enzyme

# Extension method of `init_AD`: switches on Enzyme's runtime activity when `package` is
# the Threads package symbol; for any other package it does nothing and returns `nothing`.
function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol)
    # NOTE: runtime activity is currently required for Enzyme to work correctly with threads.
    return (package == PKG_THREADS) ? Enzyme.API.runtimeActivity!(true) : nothing
end

# Forwards all arguments to `Enzyme.autodiff_deferred` and discards its result (returns `nothing`).
# NOTE: the signature `(arg, args...)` is only minimally more specific than the default
# `(args...)` method, so that the default method is not overwritten.
ParallelStencil.ParallelKernel.AD.autodiff_deferred!(arg, args...) =
    (Enzyme.autodiff_deferred(arg, args...); nothing)

# Forwards all arguments to `Enzyme.autodiff_deferred_thunk` and discards its result (returns `nothing`).
# NOTE: the signature `(arg, args...)` is only minimally more specific than the default
# `(args...)` method, so that the default method is not overwritten.
ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(arg, args...) =
    (Enzyme.autodiff_deferred_thunk(arg, args...); nothing)
10 changes: 8 additions & 2 deletions src/ParallelKernel/Exceptions.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Exceptions
export @ModuleInternalError, @MethodPluginError, @IncoherentCallError, @NotInitializedError, @IncoherentArgumentError, @KeywordArgumentError, @ArgumentEvaluationError, @ArgumentError
export ModuleInternalError, MethodPluginError, IncoherentCallError, NotInitializedError, IncoherentArgumentError, KeywordArgumentError, ArgumentEvaluationError
export @ModuleInternalError, @MethodPluginError, @IncoherentCallError, @NotInitializedError, @ExtensionLoadError, @IncoherentArgumentError, @KeywordArgumentError, @ArgumentEvaluationError, @ArgumentError
export ModuleInternalError, MethodPluginError, IncoherentCallError, NotInitializedError, ExtensionLoadError, IncoherentArgumentError, KeywordArgumentError, ArgumentEvaluationError

# Convenience macros: `@SomeError(msg)` expands to `throw(SomeError(msg))` for the exception
# type of the same name. The whole expansion is wrapped in `esc`, so both the exception type
# and `msg` are resolved in the scope of the caller, which therefore needs the type in scope.
macro ModuleInternalError(msg) esc(:(throw(ModuleInternalError($msg)))) end
macro MethodPluginError(msg) esc(:(throw(MethodPluginError($msg)))) end
macro IncoherentCallError(msg) esc(:(throw(IncoherentCallError($msg)))) end
macro NotInitializedError(msg) esc(:(throw(NotInitializedError($msg)))) end
macro ExtensionLoadError(msg) esc(:(throw(ExtensionLoadError($msg)))) end
macro IncoherentArgumentError(msg) esc(:(throw(IncoherentArgumentError($msg)))) end
macro KeywordArgumentError(msg) esc(:(throw(KeywordArgumentError($msg)))) end
macro ArgumentEvaluationError(msg) esc(:(throw(ArgumentEvaluationError($msg)))) end
Expand All @@ -31,6 +32,11 @@ struct NotInitializedError <: Exception
end
Base.showerror(io::IO, e::NotInitializedError) = print(io, "NotInitializedError: ", e.msg)

# Exception carrying a message about a package extension that was not loaded.
struct ExtensionLoadError <: Exception
    msg::String
end

# Display as "ExtensionLoadError: <msg>".
function Base.showerror(io::IO, e::ExtensionLoadError)
    return print(io, "ExtensionLoadError: ", e.msg)
end

struct IncoherentArgumentError <: Exception
msg::String
end
Expand Down
8 changes: 5 additions & 3 deletions src/ParallelKernel/ParallelKernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ To see a description of a macro or module type `?<macroname>` (including the `@`
"""
module ParallelKernel

## Alphabetical include of submodules.
include("AD.jl");
include("Data.jl");
## Include of exception module
include("Exceptions.jl");
using .Exceptions

## Alphabetical include of submodules.
include(joinpath("EnzymeExt", "AD.jl"));
include("Data.jl");

## Include of constant parameters, types and syntax sugar shared in ParallelKernel module only
include("shared.jl")

Expand Down
7 changes: 3 additions & 4 deletions src/ParallelKernel/init_parallel_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
data_module_shared = Data_shared(numbertype, indextype)
pkg_import_cmd = :()
end
ad_import_cmd = :(import ParallelStencil.ParallelKernel.Enzyme)
if (package == PKG_THREADS) ad_import_cmd = :(import ParallelStencil.ParallelKernel.Enzyme; Enzyme.API.runtimeActivity!(true)) end # NOTE: Enzyme requires this currently to work correctly with threads.
ad_init_cmd = :(ParallelStencil.ParallelKernel.AD.init_AD(ParallelStencil.ParallelKernel.PKG_THREADS))
if !isdefined(caller, :Data) || (@eval(caller, isa(Data, Module)) && length(symbols(caller, :Data)) == 1) # Only if the module Data does not exist in the caller or is empty, create it.
if (datadoc_call==:())
if (numbertype == NUMBERTYPE_NONE) datadoc_call = :(@doc ParallelStencil.ParallelKernel.DATA_DOC_NUMBERTYPE_NONE Data)
Expand All @@ -59,7 +58,7 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
@warn "Module Data cannot be created in caller module ($caller) as there is already a user defined symbol (module/variable...) with this name. ParallelStencil is still usable but without the features of the Data module."
end
@eval(caller, $pkg_import_cmd)
@eval(caller, $ad_import_cmd)
@eval(caller, $ad_init_cmd)
set_package(caller, package)
set_numbertype(caller, numbertype)
set_inbounds(caller, inbounds)
Expand Down Expand Up @@ -112,4 +111,4 @@ function extract_kwargs_nopos(caller::Module, kwargs::Dict)
else inbounds_val = false
end
return inbounds_val
end
end
2 changes: 1 addition & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const PARALLEL_DOC = """
@parallel (...) configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) backendkwargs... kernelcall
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below).
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelKernel in order to have it load the corresponding extension.
# Arguments
- `kernelcall`: a call to a kernel that is declared parallel.
Expand Down
7 changes: 3 additions & 4 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ elseif ENABLE_AMDGPU
else
const SUPPORTED_PACKAGES = [PKG_THREADS]
end
import Enzyme
using CellArrays, StaticArrays, MacroTools
import MacroTools: postwalk, splitdef, combinedef, isexpr, unblock # NOTE: inexpr_walk used instead of MacroTools.inexpr

Expand Down Expand Up @@ -60,9 +59,9 @@ const SUPPORTED_NUMBERTYPES = [Float16, Float32, Float64, Complex{Fl
const PKNumber = Union{Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}} # NOTE: this always needs to correspond to SUPPORTED_NUMBERTYPES!
const NUMBERTYPE_NONE = DataType
const AD_MODE_DEFAULT = :(Enzyme.Reverse)
const AD_DUPLICATE_DEFAULT = Enzyme.DuplicatedNoNeed
const AD_ANNOTATION_DEFAULT = Enzyme.Const
const AD_SUPPORTED_ANNOTATIONS = (Const=Enzyme.Const, Active=Enzyme.Active, Duplicated=Enzyme.Duplicated, DuplicatedNoNeed=Enzyme.DuplicatedNoNeed)
const AD_DUPLICATE_DEFAULT = :(Enzyme.DuplicatedNoNeed)
const AD_ANNOTATION_DEFAULT = :(Enzyme.Const)
const AD_SUPPORTED_ANNOTATIONS = (Const=:(Enzyme.Const), Active=:(Enzyme.Active), Duplicated=:(Enzyme.Duplicated), DuplicatedNoNeed=:(Enzyme.DuplicatedNoNeed))
const ERRMSG_UNSUPPORTED_PACKAGE = "unsupported package for parallelization"
const ERRMSG_CHECK_PACKAGE = "package has to be functional and one of the following: $(join(SUPPORTED_PACKAGES,", "))"
const ERRMSG_CHECK_NUMBERTYPE = "numbertype has to be one of the following (and evaluatable at parse time): $(join(SUPPORTED_NUMBERTYPES,", "))"
Expand Down
4 changes: 2 additions & 2 deletions src/init_parallel_stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ function init_parallel_stencil(caller::Module, package::Symbol, numbertype::Data
if (numbertype == NUMBERTYPE_NONE) datadoc_call = :(@doc replace(ParallelStencil.ParallelKernel.DATA_DOC_NUMBERTYPE_NONE, "@init_parallel_kernel" => "@init_parallel_stencil") Data)
else datadoc_call = :(@doc replace(ParallelStencil.ParallelKernel.DATA_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") Data)
end
ParallelKernel.init_parallel_kernel(caller, package, numbertype, inbounds; datadoc_call=datadoc_call)
return_expr = ParallelKernel.init_parallel_kernel(caller, package, numbertype, inbounds; datadoc_call=datadoc_call)
set_package(caller, package)
set_numbertype(caller, numbertype)
set_ndims(caller, ndims)
set_inbounds(caller, inbounds)
set_memopt(caller, memopt)
set_initialized(caller, true)
return nothing
return return_expr
end


Expand Down
2 changes: 1 addition & 1 deletion src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ See also: [`@init_parallel_stencil`](@ref)
@parallel (...) memopt=... configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) memopt=... backendkwargs... kernelcall
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below).
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelStencil in order to have it load the corresponding extension.
# Arguments
- `kernelcall`: a call to a kernel that is declared parallel.
Expand Down
16 changes: 8 additions & 8 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test
import ParallelStencil
using Enzyme
using ParallelStencil.ParallelKernel
using ParallelStencil.ParallelKernel.Enzyme
import ParallelStencil.ParallelKernel.AD
import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_THREADS, INDICES
import ParallelStencil.ParallelKernel: @require, @prettystring, @gorgeousstring, @isgpu
Expand All @@ -18,7 +18,7 @@ end
end
# Expands to `<i1> + (<i2> - 1) * size(A, 1)`, where `<i1>` and `<i2>` are the first two
# entries of `INDICES`; the expansion is escaped, so it is evaluated in the caller's scope.
macro compute(A)
    first_index, second_index = INDICES[1], INDICES[2]
    return esc(:($first_index + ($second_index - 1) * size($A, 1)))
end
# Expands to `ix + (iz - 1) * size(A, 1)` using the literal names `ix` and `iz`; the
# expansion is escaped, so those names must exist in the caller's scope.
macro compute_with_aliases(A)
    return esc(:(ix + (iz - 1) * size($A, 1)))
end

import Enzyme
@static for package in TEST_PACKAGES eval(:(
@testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
@testset "1. parallel macros" begin
Expand Down Expand Up @@ -79,12 +79,12 @@ macro compute_with_aliases(A) esc(:(ix + (iz -1)*size($A,1)
end;
end;
@testset "@parallel ∇" begin
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (Const)(A), (DuplicatedNoNeed)(B, B̄), (Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (DuplicatedNoNeed)(A, Ā), (DuplicatedNoNeed)(B, B̄), (Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (DuplicatedNoNeed)(A, Ā), (DuplicatedNoNeed)(B, B̄), (Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (DuplicatedNoNeed)(A, Ā), (Duplicated)(B, B̄), (Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel configcall = f!(A, B, a, b) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (Duplicated)(A, Ā), (Duplicated)(B, B̄), (Const)(a), (Active)(b))"
@test @prettystring(1, @parallel=(V.x->.x, V.y->.y) f!(V.x, V.y, a)) == "@parallel configcall = f!(V.x, V.y, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (DuplicatedNoNeed)(V.x, V̄.x), (DuplicatedNoNeed)(V.y, V̄.y), (Const)(a))"
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, Enzyme.Const(A), Enzyme.DuplicatedNoNeed(B, B̄), Enzyme.Const(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, Enzyme.DuplicatedNoNeed(A, Ā), Enzyme.DuplicatedNoNeed(B, B̄), Enzyme.Const(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, Enzyme.DuplicatedNoNeed(A, Ā), Enzyme.DuplicatedNoNeed(B, B̄), Enzyme.Const(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, Enzyme.DuplicatedNoNeed(A, Ā), Enzyme.Duplicated(B, B̄), Enzyme.Const(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel configcall = f!(A, B, a, b) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, Enzyme.Duplicated(A, Ā), Enzyme.Duplicated(B, B̄), Enzyme.Const(a), Enzyme.Active(b))"
@test @prettystring(1, @parallel=(V.x->.x, V.y->.y) f!(V.x, V.y, a)) == "@parallel configcall = f!(V.x, V.y, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, Enzyme.DuplicatedNoNeed(V.x, V̄.x), Enzyme.DuplicatedNoNeed(V.y, V̄.y), Enzyme.Const(a))"
end;
@testset "AD.autodiff_deferred!" begin
@static if $package == $PKG_THREADS
Expand Down
Loading

0 comments on commit d49dd44

Please sign in to comment.