Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,22 @@ end

Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()

Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
arrayref(A, i1)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
arrayset(A, convert(T, x)::T, i1)
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref(
A, i1
)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset(
A, convert(T, x)::T, i1
)

# preserve the specific integer type when indexing device arrays,
# to avoid extending 32-bit hardware indices to 64-bit.
Base.to_index(::CuTracedArray, i::Integer) = i

# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
# See also: https://github.com/JuliaLang/julia/pull/42289
Base.@propagate_inbounds Base.getindex(
A::CuTracedArray, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)]
Base.@propagate_inbounds Base.getindex(A::CuTracedArray, I::Union{Integer,CartesianIndex}...) = A[Base._to_linear_index(
A, to_indices(A, I)...
)]
Base.@propagate_inbounds Base.setindex!(
A::CuTracedArray, x, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x
Expand Down
10 changes: 4 additions & 6 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,10 @@ function trace_for(mod, expr)
step = length(range.args) == 3 ? 1 : range.args[3]
limit = range.args[end]

body_symbols = ExpressionExplorer.compute_symbols_state(
quote
$(Expr(:local, assign))
$body
end,
)
body_symbols = ExpressionExplorer.compute_symbols_state(quote
$(Expr(:local, assign))
$body
end)

external_syms = body_symbols.assignments ∪ body_symbols.references
filter!(∉(SPECIAL_SYMBOLS), external_syms)
Expand Down
2 changes: 1 addition & 1 deletion src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
end

# TODO is there any way to allocate an uninitialized buffer in XLA?
function Base.similar(a::ConcreteRArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
function Base.similar(a::ConcreteRArray{T}, (::Type{S})=T, dims::Dims=size(a)) where {T,S}
return ConcreteRArray(
Array{S}(undef, dims); client=XLA.client(a.data), device=XLA.device(a.data)
)
Expand Down
53 changes: 30 additions & 23 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false,
false,
false,
false,
set_reactant_abi,
)
end
Expand All @@ -81,10 +81,10 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false,
false,
false,
false,
set_reactant_abi,
)
end
Expand All @@ -97,20 +97,25 @@ const enzyme_dupnoneed = 3
const enzyme_outnoneed = 4
const enzyme_constnoneed = 5

@inline act_from_type(x, reverse, needs_primal=true) =
throw(AssertionError("Unhandled activity $(typeof(x))"))
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) =
act_from_type(Enzyme.Const, reverse, needs_primal)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(x, reverse, needs_primal=true) = throw(
AssertionError("Unhandled activity $(typeof(x))")
)
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type(
Enzyme.Const, reverse, needs_primal
)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.BatchDuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) =
act_from_type(Enzyme.Active, reverse, needs_primal)
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_type(
Enzyme.Active, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) =
if needs_primal
enzyme_const
Expand All @@ -132,10 +137,12 @@ const enzyme_constnoneed = 5
end
end

@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) =
act_from_type(Enzyme.DuplicatedNoNeed, Reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = act_from_type(
Enzyme.DuplicatedNoNeed, Reverse, needs_primal
)

@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) =
if needs_primal
Expand Down Expand Up @@ -465,7 +472,7 @@ function overload_autodiff(
false,
TracedUtils.transpose_val(MLIR.IR.result(res, residx));
emptypaths=true,
) #=reverse=#
)
residx += 1
continue
end
Expand Down
2 changes: 1 addition & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ for randfun in (:rand, :randn, :randexp)

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
rng::AbstractRNG, (::Type{T})=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ include("Compiler.jl")
include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false)
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
Expand Down
5 changes: 2 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,8 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
end

indices = [
(
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
).mlir_data for i in indices
(TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data
for i in indices
]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dynamic_update_slice(
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ for (jlop, hloop, hlocomp) in (
function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp))
return Ops.compare(lhs, rhs; comparison_direction=($(hlocomp)))
end

function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
Expand Down
8 changes: 4 additions & 4 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ function materialize_traced_array(
end

get_mlir_data(x::TracedRNumber) = x.mlir_data
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data=data; return x)
get_paths(x::TracedRNumber) = x.paths
set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x)
set_paths!(x::TracedRNumber, paths) = (x.paths=paths; return x)

get_mlir_data(x::TracedRArray) = x.mlir_data
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))
get_paths(x::TracedRArray) = x.paths
set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x)
set_paths!(x::TracedRArray, paths) = (x.paths=paths; return x)

get_paths(x::MissingTracedValue) = x.paths
set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x)
set_paths!(x::MissingTracedValue, paths) = (x.paths=paths; return x)

function set_mlir_data!(x::TracedRArray, data)
x.mlir_data = data
Expand Down
24 changes: 15 additions & 9 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,21 @@ Base.@nospecializeinfer function traced_type_inner(
end
end

Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) =
UnionAll(TV.var, base_typet(TV.body))
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) =
TracedRArray{TV.parameters...}

Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) =
UnionAll(TV.var, base_typec(TV.body))
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) =
(TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = UnionAll(
TV.var, base_typet(TV.body)
)
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = TracedRArray{
TV.parameters...
}

Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = UnionAll(
TV.var, base_typec(TV.body)
)
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = (
TV<:TracedRArray ? ConcreteRArray : ConcreteRNumber
){
TV.parameters...
}

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:ConcreteRArray}),
Expand Down
18 changes: 12 additions & 6 deletions src/mlir/IR/AffineExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ ismultipleof(expr::AffineExpr, factor) = API.mlirAffineExprIsMultipleOf(expr, fa

Checks whether the given affine expression involves AffineDimExpr 'position'.
"""
isfunctionofdimexpr(expr::AffineExpr, position) =
function isfunctionofdimexpr(expr::AffineExpr, position)
API.mlirAffineExprIsFunctionOfDim(expr, position)
end

"""
isdimexpr(affineExpr)
Expand All @@ -71,8 +72,9 @@ isdimexpr(expr::AffineExpr) = API.mlirAffineExprIsADim(expr)

Creates an affine dimension expression with 'position' in the context.
"""
AffineDimensionExpr(position; context::Context=context()) =
function AffineDimensionExpr(position; context::Context=context())
AffineExpr(API.mlirAffineDimExprGet(context, position))
end

"""
issymbolexpr(affineExpr)
Expand All @@ -86,8 +88,9 @@ issymbolexpr(expr::AffineExpr) = API.mlirAffineExprIsASymbol(expr)

Creates an affine symbol expression with 'position' in the context.
"""
SymbolExpr(position; context::Context=context()) =
function SymbolExpr(position; context::Context=context())
AffineExpr(API.mlirAffineSymbolExprGet(context, position))
end

"""
position(affineExpr)
Expand Down Expand Up @@ -120,8 +123,9 @@ isconstantexpr(expr::AffineExpr) = API.mlirAffineExprIsAConstant(expr)

Creates an affine constant expression with 'constant' in the context.
"""
ConstantExpr(constant; context::Context=context()) =
function ConstantExpr(constant; context::Context=context())
AffineExpr(API.mlirAffineConstantExprGet(context, constant))
end

"""
value(affineExpr)
Expand Down Expand Up @@ -189,8 +193,9 @@ isfloordiv(expr::AffineExpr) = API.mlirAffineExprIsAFloorDiv(expr)

Creates an affine floordiv expression with 'lhs' and 'rhs'.
"""
Base.div(lhs::AffineExpr, rhs::AffineExpr) =
function Base.div(lhs::AffineExpr, rhs::AffineExpr)
AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs))
end
Base.fld(lhs::AffineExpr, rhs::AffineExpr) = div(lhs, rhs)

"""
Expand All @@ -205,8 +210,9 @@ isceildiv(expr::AffineExpr) = API.mlirAffineExprIsACeilDiv(expr)

Creates an affine ceildiv expression with 'lhs' and 'rhs'.
"""
Base.cld(lhs::AffineExpr, rhs::AffineExpr) =
function Base.cld(lhs::AffineExpr, rhs::AffineExpr)
AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs))
end

"""
isbinary(affineExpr)
Expand Down
31 changes: 21 additions & 10 deletions src/mlir/IR/AffineMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,37 @@ context(map::AffineMap) = API.mlirAffineMapGetContext(map)
Creates a zero result affine map of the given dimensions and symbols in the context.
The affine map is owned by the context.
"""
AffineMap(ndims, nsymbols; context::Context=context()) =
function AffineMap(ndims, nsymbols; context::Context=context())
AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols))
end

"""
AffineMap(ndims, nsymbols, affineExprs; context=context())

Creates an affine map with results defined by the given list of affine expressions.
The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results.
"""
AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) =
function AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context())
AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs))
end

"""
ConstantAffineMap(val; context=context())

Creates a single constant result affine map in the context. The affine map is owned by the context.
"""
ConstantAffineMap(val; context::Context=context()) =
function ConstantAffineMap(val; context::Context=context())
AffineMap(API.mlirAffineMapConstantGet(context, val))
end

"""
IdentityAffineMap(ndims; context=context())

Creates an affine map with 'ndims' identity in the context. The affine map is owned by the context.
"""
IdentityAffineMap(ndims; context::Context=context()) =
function IdentityAffineMap(ndims; context::Context=context())
AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims))
end

"""
MinorIdentityAffineMap(ndims, nresults; context=context())
Expand Down Expand Up @@ -189,8 +193,9 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map)

Returns the affine map consisting of the `positions` subset.
"""
submap(map::AffineMap, pos::Vector{Int}) =
function submap(map::AffineMap, pos::Vector{Int})
AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos))
end

"""
majorsubmap(affineMap, nresults)
Expand All @@ -199,28 +204,34 @@ Returns the affine map consisting of the most major `nresults` results.
Returns the null AffineMap if the `nresults` is equal to zero.
Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map.
"""
majorsubmap(map::AffineMap, nresults) =
function majorsubmap(map::AffineMap, nresults)
AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults))
end

"""
minorsubmap(affineMap, nresults)

Returns the affine map consisting of the most minor `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero.
Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map.
"""
minorsubmap(map::AffineMap, nresults) =
function minorsubmap(map::AffineMap, nresults)
AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults))
end

"""
mlirAffineMapReplace(affineMap, expression => replacement, numResultDims, numResultSyms)

Apply `AffineExpr::replace(map)` to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols.
"""
Base.replace(
function Base.replace(
map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms
) = AffineMap(
API.mlirAffineMapReplace(map, old_new.first, old_new.second, nresultdims, nresultsyms),
)
AffineMap(
API.mlirAffineMapReplace(
map, old_new.first, old_new.second, nresultdims, nresultsyms
),
)
end

"""
simplify(affineMaps, size, result, populateResult)
Expand Down
Loading